1 // -*- mode: rust; -*-
2 //
3 // This file is part of curve25519-dalek.
4 // Copyright (c) 2016-2019 Isis Lovecruft, Henry de Valence
5 // See LICENSE for licensing information.
6 //
7 // Authors:
8 // - Isis Agora Lovecruft <isis@patternsinthevoid.net>
9 // - Henry de Valence <hdevalence@hdevalence.ca>
10 
11 // We allow non snake_case names because coordinates in projective space are
12 // traditionally denoted by the capitalisation of their respective
13 // counterparts in affine space.  Yeah, you heard me, rustc, I'm gonna have my
14 // affine and projective cakes and eat both of them too.
15 #![allow(non_snake_case)]
16 
17 //! An implementation of [Ristretto][ristretto_main], which provides a
18 //! prime-order group.
19 //!
20 //! # The Ristretto Group
21 //!
22 //! Ristretto is a modification of Mike Hamburg's Decaf scheme to work
23 //! with cofactor-\\(8\\) curves, such as Curve25519.
24 //!
25 //! The introduction of the Decaf paper, [_Decaf:
26 //! Eliminating cofactors through point
27 //! compression_](https://eprint.iacr.org/2015/673.pdf), notes that while
28 //! most cryptographic systems require a group of prime order, most
29 //! concrete implementations using elliptic curve groups fall short –
30 //! they either provide a group of prime order, but with incomplete or
31 //! variable-time addition formulae (for instance, most Weierstrass
32 //! models), or else they provide a fast and safe implementation of a
33 //! group whose order is not quite a prime \\(q\\), but \\(hq\\) for a
34 //! small cofactor \\(h\\) (for instance, Edwards curves, which have
35 //! cofactor at least \\(4\\)).
36 //!
37 //! This abstraction mismatch is commonly “handled” by pushing the
38 //! complexity upwards, adding ad-hoc protocol modifications.  But
39 //! these modifications require careful analysis and are a recurring
40 //! source of [vulnerabilities][cryptonote] and [design
41 //! complications][ed25519_hkd].
42 //!
43 //! Instead, Decaf (and Ristretto) use a quotient group to implement a
44 //! prime-order group using a non-prime-order curve.  This provides
45 //! the correct abstraction for cryptographic systems, while retaining
46 //! the speed and safety benefits of an Edwards curve.
47 //!
48 //! Decaf is named “after the procedure which divides the effect of
49 //! coffee by \\(4\\)”.  However, Curve25519 has a cofactor of
50 //! \\(8\\).  To eliminate its cofactor, Ristretto restricts further;
51 //! this [additional restriction][ristretto_coffee] gives the
52 //! _Ristretto_ encoding.
53 //!
54 //! More details on why Ristretto is necessary can be found in the
55 //! [Why Ristretto?][why_ristretto] section of the Ristretto website.
56 //!
//! Ristretto points are provided in `curve25519-dalek` by the
//! `RistrettoPoint` struct.
60 //!
61 //! ## Encoding and Decoding
62 //!
63 //! Encoding is done by converting to and from a `CompressedRistretto`
64 //! struct, which is a typed wrapper around `[u8; 32]`.
65 //!
66 //! The encoding is not batchable, but it is possible to
67 //! double-and-encode in a batch using
68 //! `RistrettoPoint::double_and_compress_batch`.
69 //!
70 //! ## Equality Testing
71 //!
72 //! Testing equality of points on an Edwards curve in projective
73 //! coordinates requires an expensive inversion.  By contrast, equality
74 //! checking in the Ristretto group can be done in projective
75 //! coordinates without requiring an inversion, so it is much faster.
76 //!
77 //! The `RistrettoPoint` struct implements the
78 //! `subtle::ConstantTimeEq` trait for constant-time equality
79 //! checking, and the Rust `Eq` trait for variable-time equality
80 //! checking.
81 //!
82 //! ## Scalars
83 //!
84 //! Scalars are represented by the `Scalar` struct.  Each scalar has a
85 //! canonical representative mod the group order.  To attempt to load
86 //! a supposedly-canonical scalar, use
87 //! `Scalar::from_canonical_bytes()`. To check whether a
88 //! representative is canonical, use `Scalar::is_canonical()`.
89 //!
90 //! ## Scalar Multiplication
91 //!
92 //! Scalar multiplication on Ristretto points is provided by:
93 //!
94 //! * the `*` operator between a `Scalar` and a `RistrettoPoint`, which
95 //! performs constant-time variable-base scalar multiplication;
96 //!
97 //! * the `*` operator between a `Scalar` and a
98 //! `RistrettoBasepointTable`, which performs constant-time fixed-base
99 //! scalar multiplication;
100 //!
101 //! * an implementation of the
102 //! [`MultiscalarMul`](../traits/trait.MultiscalarMul.html) trait for
103 //! constant-time variable-base multiscalar multiplication;
104 //!
105 //! * an implementation of the
106 //! [`VartimeMultiscalarMul`](../traits/trait.VartimeMultiscalarMul.html)
//! trait for variable-time variable-base multiscalar multiplication.
108 //!
109 //! ## Random Points and Hashing to Ristretto
110 //!
111 //! The Ristretto group comes equipped with an Elligator map.  This is
112 //! used to implement
113 //!
114 //! * `RistrettoPoint::random()`, which generates random points from an
115 //! RNG;
116 //!
117 //! * `RistrettoPoint::from_hash()` and
118 //! `RistrettoPoint::hash_from_bytes()`, which perform hashing to the
119 //! group.
120 //!
121 //! The Elligator map itself is not currently exposed.
122 //!
123 //! ## Implementation
124 //!
125 //! The Decaf suggestion is to use a quotient group, such as \\(\mathcal
126 //! E / \mathcal E[4]\\) or \\(2 \mathcal E / \mathcal E[2] \\), to
127 //! implement a prime-order group using a non-prime-order curve.
128 //!
129 //! This requires only changing
130 //!
131 //! 1. the function for equality checking (so that two representatives
132 //!    of the same coset are considered equal);
133 //! 2. the function for encoding (so that two representatives of the
134 //!    same coset are encoded as identical bitstrings);
135 //! 3. the function for decoding (so that only the canonical encoding of
136 //!    a coset is accepted).
137 //!
138 //! Internally, each coset is represented by a curve point; two points
139 //! \\( P, Q \\) may represent the same coset in the same way that two
140 //! points with different \\(X,Y,Z\\) coordinates may represent the
141 //! same point.  The group operations are carried out with no overhead
142 //! using Edwards formulas.
143 //!
144 //! Notes on the details of the encoding can be found in the
145 //! [Details][ristretto_notes] section of the Ristretto website.
146 //!
147 //! [cryptonote]:
148 //! https://moderncrypto.org/mail-archive/curves/2017/000898.html
149 //! [ed25519_hkd]:
150 //! https://moderncrypto.org/mail-archive/curves/2017/000858.html
151 //! [ristretto_coffee]:
152 //! https://en.wikipedia.org/wiki/Ristretto
153 //! [ristretto_notes]:
154 //! https://ristretto.group/details/index.html
155 //! [why_ristretto]:
156 //! https://ristretto.group/why_ristretto.html
157 //! [ristretto_main]:
158 //! https://ristretto.group/
159 
160 use core::borrow::Borrow;
161 use core::fmt::Debug;
162 use core::iter::Sum;
163 use core::ops::{Add, Neg, Sub};
164 use core::ops::{AddAssign, SubAssign};
165 use core::ops::{Mul, MulAssign};
166 
167 use rand_core::{CryptoRng, RngCore};
168 
169 use digest::generic_array::typenum::U64;
170 use digest::Digest;
171 
172 use constants;
173 use field::FieldElement;
174 
175 use subtle::Choice;
176 use subtle::ConditionallySelectable;
177 use subtle::ConditionallyNegatable;
178 use subtle::ConstantTimeEq;
179 
180 use edwards::EdwardsBasepointTable;
181 use edwards::EdwardsPoint;
182 
183 #[allow(unused_imports)]
184 use prelude::*;
185 
186 use scalar::Scalar;
187 
188 use traits::Identity;
189 #[cfg(any(feature = "alloc", feature = "std"))]
190 use traits::{MultiscalarMul, VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul};
191 
192 #[cfg(not(all(
193     feature = "simd_backend",
194     any(target_feature = "avx2", target_feature = "avx512ifma")
195 )))]
196 use backend::serial::scalar_mul;
197 #[cfg(all(
198     feature = "simd_backend",
199     any(target_feature = "avx2", target_feature = "avx512ifma")
200 ))]
201 use backend::vector::scalar_mul;
202 
203 // ------------------------------------------------------------------------
204 // Compressed points
205 // ------------------------------------------------------------------------
206 
/// The 32-byte compressed (wire-format) encoding of a Ristretto point.
///
/// Because the Ristretto encoding is canonical, two points are equal
/// exactly when their encodings are equal, so comparing two
/// `CompressedRistretto` values byte-wise compares the points they encode.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CompressedRistretto(pub [u8; 32]);
213 
214 impl ConstantTimeEq for CompressedRistretto {
ct_eq(&self, other: &CompressedRistretto) -> Choice215     fn ct_eq(&self, other: &CompressedRistretto) -> Choice {
216         self.as_bytes().ct_eq(other.as_bytes())
217     }
218 }
219 
impl CompressedRistretto {
    /// Copy the bytes of this `CompressedRistretto`.
    pub fn to_bytes(&self) -> [u8; 32] {
        self.0
    }

    /// View this `CompressedRistretto` as an array of bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Construct a `CompressedRistretto` from a slice of bytes.
    ///
    /// The bytes are copied verbatim; no check is made here that they
    /// encode a valid point (use `decompress` for that).
    ///
    /// # Panics
    ///
    /// If the input `bytes` slice does not have a length of 32.
    pub fn from_slice(bytes: &[u8]) -> CompressedRistretto {
        let mut tmp = [0u8; 32];

        // `copy_from_slice` panics on a length mismatch; this is the
        // panic documented above.
        tmp.copy_from_slice(bytes);

        CompressedRistretto(tmp)
    }

    /// Attempt to decompress to a `RistrettoPoint`.
    ///
    /// # Return
    ///
    /// - `Some(RistrettoPoint)` if `self` was the canonical encoding of a point;
    ///
    /// - `None` if `self` was not the canonical encoding of a point.
    ///
    /// # Note
    ///
    /// The validity checks end in ordinary `if` branches (an early return
    /// after step 1 and a final check after step 2), so the running time
    /// depends on whether, and at which step, decoding fails.
    pub fn decompress(&self) -> Option<RistrettoPoint> {
        // Step 1. Check s for validity:
        // 1.a) s must be 32 bytes (we get this from the type system)
        // 1.b) s < p
        // 1.c) s is nonnegative
        //
        // Our decoding routine ignores the high bit, so the only
        // possible failure for 1.b) is if someone encodes s in 0..18
        // as s+p in 2^255-19..2^255-1.  We can check this by
        // converting back to bytes, and checking that we get the
        // original input, since our encoding routine is canonical.

        let s = FieldElement::from_bytes(self.as_bytes());
        // Round-trip: a non-canonical input re-encodes to different bytes.
        let s_bytes_check = s.to_bytes();
        let s_encoding_is_canonical =
            &s_bytes_check[..].ct_eq(self.as_bytes());
        let s_is_negative = s.is_negative();

        if s_encoding_is_canonical.unwrap_u8() == 0u8 || s_is_negative.unwrap_u8() == 1u8 {
            return None;
        }

        // Step 2.  Compute (X:Y:Z:T).
        let one = FieldElement::one();
        let ss = s.square();
        let u1 = &one - &ss;      //  1 + as²
        let u2 = &one + &ss;      //  1 - as²    where a=-1
        let u2_sqr = u2.square(); // (1 - as²)²

        // v == ad(1+as²)² - (1-as²)²            where d=-121665/121666
        let v = &(&(-&constants::EDWARDS_D) * &u1.square()) - &u2_sqr;

        // `ok` reports whether the inverse square root exists; it is
        // checked in the final validity test below.
        let (ok, I) = (&v * &u2_sqr).invsqrt(); // 1/sqrt(v*u_2²)

        let Dx = &I * &u2;         // 1/sqrt(v)
        let Dy = &I * &(&Dx * &v); // 1/u2

        // x == | 2s/sqrt(v) | == + sqrt(4s²/(ad(1+as²)² - (1-as²)²))
        let mut x = &(&s + &s) * &Dx;
        let x_neg = x.is_negative();
        x.conditional_negate(x_neg);

        // y == (1-as²)/(1+as²)
        let y = &u1 * &Dy;

        // t == ((1+as²) sqrt(4s²/(ad(1+as²)² - (1-as²)²)))/(1-as²)
        let t = &x * &y;

        // Reject when the inverse square root did not exist, when t is
        // negative, or when y is zero; otherwise return the point with Z = 1.
        if ok.unwrap_u8() == 0u8 || t.is_negative().unwrap_u8() == 1u8 || y.is_zero().unwrap_u8() == 1u8 {
            None
        } else {
            Some(RistrettoPoint(EdwardsPoint{X: x, Y: y, Z: one, T: t}))
        }
    }
}
306 
307 impl Identity for CompressedRistretto {
identity() -> CompressedRistretto308     fn identity() -> CompressedRistretto {
309         CompressedRistretto([0u8; 32])
310     }
311 }
312 
313 impl Default for CompressedRistretto {
default() -> CompressedRistretto314     fn default() -> CompressedRistretto {
315         CompressedRistretto::identity()
316     }
317 }
318 
319 // ------------------------------------------------------------------------
320 // Serde support
321 // ------------------------------------------------------------------------
322 // Serializes to and from `RistrettoPoint` directly, doing compression
323 // and decompression internally.  This means that users can create
324 // structs containing `RistrettoPoint`s and use Serde's derived
325 // serializers to serialize those structures.
326 
327 #[cfg(feature = "serde")]
328 use serde::{self, Serialize, Deserialize, Serializer, Deserializer};
329 #[cfg(feature = "serde")]
330 use serde::de::Visitor;
331 
#[cfg(feature = "serde")]
impl Serialize for RistrettoPoint {
    /// Serialize the point as its compressed Ristretto encoding: a
    /// 32-element tuple of bytes, exactly as `CompressedRistretto` does.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: Serializer
    {
        let encoded = self.compress();
        encoded.serialize(serializer)
    }
}
345 
#[cfg(feature = "serde")]
impl Serialize for CompressedRistretto {
    /// Serialize the encoding as a fixed-length tuple of 32 bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: Serializer
    {
        use serde::ser::SerializeTuple;

        let mut tuple = serializer.serialize_tuple(32)?;
        for b in self.0.iter() {
            tuple.serialize_element(b)?;
        }
        tuple.end()
    }
}
359 
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RistrettoPoint {
    /// Deserialize a `RistrettoPoint` from a 32-element tuple of bytes
    /// holding its Ristretto encoding.
    ///
    /// The bytes are decompressed, so only canonical encodings of valid
    /// points are accepted. A short sequence yields an `invalid_length`
    /// error, and bytes that fail to decompress yield a custom
    /// "decompression failed" error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: Deserializer<'de>
    {
        struct RistrettoPointVisitor;

        impl<'de> Visitor<'de> for RistrettoPointVisitor {
            type Value = RistrettoPoint;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                formatter.write_str("a valid point in Ristretto format")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<RistrettoPoint, A::Error>
                where A: serde::de::SeqAccess<'de>
            {
                let mut bytes = [0u8; 32];
                for (i, byte) in bytes.iter_mut().enumerate() {
                    // Build errors lazily: `ok_or` would construct (and, for
                    // many formats, allocate) an error value for every byte
                    // read, even when the byte is present.
                    *byte = seq.next_element()?
                        .ok_or_else(|| serde::de::Error::invalid_length(i, &"expected 32 bytes"))?;
                }
                CompressedRistretto(bytes)
                    .decompress()
                    .ok_or_else(|| serde::de::Error::custom("decompression failed"))
            }
        }

        deserializer.deserialize_tuple(32, RistrettoPointVisitor)
    }
}
391 
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for CompressedRistretto {
    /// Deserialize a `CompressedRistretto` from a 32-element tuple of bytes.
    ///
    /// No decompression or validity check is performed; the bytes are
    /// stored as-is. A short sequence yields an `invalid_length` error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: Deserializer<'de>
    {
        struct CompressedRistrettoVisitor;

        impl<'de> Visitor<'de> for CompressedRistrettoVisitor {
            type Value = CompressedRistretto;

            fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                formatter.write_str("32 bytes of data")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<CompressedRistretto, A::Error>
                where A: serde::de::SeqAccess<'de>
            {
                let mut bytes = [0u8; 32];
                for (i, byte) in bytes.iter_mut().enumerate() {
                    // Build the error lazily so that reading a present byte
                    // never constructs (or allocates) an unused error value.
                    *byte = seq.next_element()?
                        .ok_or_else(|| serde::de::Error::invalid_length(i, &"expected 32 bytes"))?;
                }
                Ok(CompressedRistretto(bytes))
            }
        }

        deserializer.deserialize_tuple(32, CompressedRistrettoVisitor)
    }
}
421 
422 // ------------------------------------------------------------------------
423 // Internal point representations
424 // ------------------------------------------------------------------------
425 
426 /// A `RistrettoPoint` represents a point in the Ristretto group for
427 /// Curve25519.  Ristretto, a variant of Decaf, constructs a
428 /// prime-order group as a quotient group of a subgroup of (the
429 /// Edwards form of) Curve25519.
430 ///
431 /// Internally, a `RistrettoPoint` is implemented as a wrapper type
432 /// around `EdwardsPoint`, with custom equality, compression, and
433 /// decompression routines to account for the quotient.  This means that
434 /// operations on `RistrettoPoint`s are exactly as fast as operations on
435 /// `EdwardsPoint`s.
436 ///
437 #[derive(Copy, Clone)]
438 pub struct RistrettoPoint(pub(crate) EdwardsPoint);
439 
impl RistrettoPoint {
    /// Compress this point using the Ristretto encoding.
    ///
    /// The returned encoding always has a nonnegative `s`, which is what
    /// `CompressedRistretto::decompress` requires.
    pub fn compress(&self) -> CompressedRistretto {
        let mut X = self.0.X;
        let mut Y = self.0.Y;
        let Z = &self.0.Z;
        let T = &self.0.T;

        // A single inverse square root of u1*u2² yields both
        // i1 = invsqrt*u1 and i2 = invsqrt*u2 below.
        let u1 = &(Z + &Y) * &(Z - &Y);
        let u2 = &X * &Y;
        // Ignore return value since this is always square
        let (_, invsqrt) = (&u1 * &u2.square()).invsqrt();
        let i1 = &invsqrt * &u1;
        let i2 = &invsqrt * &u2;
        let z_inv = &i1 * &(&i2 * T);
        let mut den_inv = i2;

        let iX = &X * &constants::SQRT_M1;
        let iY = &Y * &constants::SQRT_M1;
        let ristretto_magic = &constants::INVSQRT_A_MINUS_D;
        let enchanted_denominator = &i1 * ristretto_magic;

        // When `T * z_inv` is negative, switch to (X, Y) = (i*Y, i*X) and
        // the matching denominator. These selections use
        // `conditional_assign`, not branches.
        let rotate = (T * &z_inv).is_negative();

        X.conditional_assign(&iY, rotate);
        Y.conditional_assign(&iX, rotate);
        den_inv.conditional_assign(&enchanted_denominator, rotate);

        // Fix the sign of Y according to the sign of X * z_inv.
        Y.conditional_negate((&X * &z_inv).is_negative());

        // Negate s if needed so the encoded value is nonnegative.
        let mut s = &den_inv * &(Z - &Y);
        let s_is_negative = s.is_negative();
        s.conditional_negate(s_is_negative);

        CompressedRistretto(s.to_bytes())
    }

    /// Double-and-compress a batch of points.  The Ristretto encoding
    /// is not batchable, since it requires an inverse square root.
    ///
    /// However, given input points \\( P\_1, \ldots, P\_n, \\)
    /// it is possible to compute the encodings of their doubles \\(
    /// \mathrm{enc}( [2]P\_1), \ldots, \mathrm{enc}( [2]P\_n ) \\)
    /// in a batch.
    ///
    /// ```
    /// # extern crate curve25519_dalek;
    /// # use curve25519_dalek::ristretto::RistrettoPoint;
    /// extern crate rand_core;
    /// use rand_core::OsRng;
    ///
    /// # // Need fn main() here in comment so the doctest compiles
    /// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
    /// # fn main() {
    /// let mut rng = OsRng;
    /// let points: Vec<RistrettoPoint> =
    ///     (0..32).map(|_| RistrettoPoint::random(&mut rng)).collect();
    ///
    /// let compressed = RistrettoPoint::double_and_compress_batch(&points);
    ///
    /// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
    ///     assert_eq!(*P2_compressed, (P + P).compress());
    /// }
    /// # }
    /// ```
    #[cfg(feature = "alloc")]
    pub fn double_and_compress_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
        where I: IntoIterator<Item = &'a RistrettoPoint>
    {
        // Per-point intermediates. `eg` and `fh` are cached because they
        // are reused after the batched inversion below.
        #[derive(Copy, Clone, Debug)]
        struct BatchCompressState {
            e: FieldElement,
            f: FieldElement,
            g: FieldElement,
            h: FieldElement,
            eg: FieldElement,
            fh: FieldElement,
        }

        impl BatchCompressState {
            // The product e*f*g*h, i.e. the value that gets batch-inverted.
            fn efgh(&self) -> FieldElement {
                &self.eg * &self.fh
            }
        }

        impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
            fn from(P: &'a RistrettoPoint) -> BatchCompressState {
                let XX = P.0.X.square();
                let YY = P.0.Y.square();
                let ZZ = P.0.Z.square();
                let dTT = &P.0.T.square() * &constants::EDWARDS_D;

                let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
                let f = &ZZ + &dTT;                  // = Z^2 + d*T^2
                let g = &YY + &XX;                   // = Y^2 - a*X^2
                let h = &ZZ - &dTT;                  // = Z^2 - d*T^2

                let eg = &e * &g;
                let fh = &f * &h;

                BatchCompressState{ e, f, g, h, eg, fh }
            }
        }

        let states: Vec<BatchCompressState> = points.into_iter().map(BatchCompressState::from).collect();

        // One `batch_invert` call over all e*f*g*h products replaces a
        // separate per-point inverse square root.
        let mut invs: Vec<FieldElement> = states.iter().map(|state| state.efgh()).collect();

        FieldElement::batch_invert(&mut invs[..]);

        states.iter().zip(invs.iter()).map(|(state, inv): (&BatchCompressState, &FieldElement)| {
            let Zinv = &state.eg * &inv;
            let Tinv = &state.fh * &inv;

            let mut magic = constants::INVSQRT_A_MINUS_D;

            let negcheck1 = (&state.eg * &Zinv).is_negative();

            let mut e = state.e;
            let mut g = state.g;
            let mut h = state.h;

            let minus_e = -&e;
            let f_times_sqrta = &state.f * &constants::SQRT_M1;

            // Constant-time selection of the rotated values when negcheck1 is set.
            e.conditional_assign(&state.g,       negcheck1);
            g.conditional_assign(&minus_e,       negcheck1);
            h.conditional_assign(&f_times_sqrta, negcheck1);

            magic.conditional_assign(&constants::SQRT_M1, negcheck1);

            let negcheck2 = (&(&h * &e) * &Zinv).is_negative();

            g.conditional_negate(negcheck2);

            let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));

            // As in `compress`, emit the nonnegative s.
            let s_is_negative = s.is_negative();
            s.conditional_negate(s_is_negative);

            CompressedRistretto(s.to_bytes())
        }).collect()
    }


    /// Return the coset self + E[4], for debugging.
    ///
    /// NOTE(review): not called anywhere in the visible part of this file;
    /// presumably used by tests elsewhere — confirm.
    fn coset4(&self) -> [EdwardsPoint; 4] {
        [  self.0
        , &self.0 + &constants::EIGHT_TORSION[2]
        , &self.0 + &constants::EIGHT_TORSION[4]
        , &self.0 + &constants::EIGHT_TORSION[6]
        ]
    }

    /// Computes the Ristretto Elligator map.
    ///
    /// # Note
    ///
    /// This method is not public because it's just used for hashing
    /// to a point -- proper elligator support is deferred for now.
    ///
    /// NOTE(review): despite the note above, this function is declared
    /// `pub fn`, and the module docs also say the Elligator map is not
    /// currently exposed — confirm the intended visibility.
    pub fn elligator_ristretto_flavor(r_0: &FieldElement) -> RistrettoPoint {
        let i = &constants::SQRT_M1;
        let d = &constants::EDWARDS_D;
        let one_minus_d_sq = &constants::ONE_MINUS_EDWARDS_D_SQUARED;
        let d_minus_one_sq = &constants::EDWARDS_D_MINUS_ONE_SQUARED;
        let mut c = constants::MINUS_ONE;

        let one = FieldElement::one();

        let r = i * &r_0.square();
        let N_s = &(&r + &one) * &one_minus_d_sq;
        let D = &(&c - &(d * &r)) * &(&r + d);

        // When N_s/D is not square, fall back (in constant time) to
        // s' = s*r_0 made negative, and replace c by r.
        let (Ns_D_is_sq, mut s) = FieldElement::sqrt_ratio_i(&N_s, &D);
        let mut s_prime = &s * r_0;
        let s_prime_is_pos = !s_prime.is_negative();
        s_prime.conditional_negate(s_prime_is_pos);

        s.conditional_assign(&s_prime, !Ns_D_is_sq);
        c.conditional_assign(&r, !Ns_D_is_sq);

        let N_t = &(&(&c * &(&r - &one)) * &d_minus_one_sq) - &D;
        let s_sq = s.square();

        use backend::serial::curve_models::CompletedPoint;

        // The conversion from W_i is exactly the conversion from P1xP1.
        RistrettoPoint(CompletedPoint{
            X: &(&s + &s) * &D,
            Z: &N_t * &constants::SQRT_AD_MINUS_ONE,
            Y: &FieldElement::one() - &s_sq,
            T: &FieldElement::one() + &s_sq,
        }.to_extended())
    }

    /// Return a `RistrettoPoint` chosen uniformly at random using a user-provided RNG.
    ///
    /// # Inputs
    ///
    /// * `rng`: any RNG which implements the `RngCore + CryptoRng` interface.
    ///
    /// # Returns
    ///
    /// A random element of the Ristretto group.
    ///
    /// # Implementation
    ///
    /// Uses the Ristretto-flavoured Elligator 2 map, so that the
    /// discrete log of the output point with respect to any other
    /// point should be unknown.  The map is applied twice and the
    /// results are added, to ensure a uniform distribution.
    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
        // 64 random bytes: one 32-byte half per Elligator application.
        let mut uniform_bytes = [0u8; 64];
        rng.fill_bytes(&mut uniform_bytes);

        RistrettoPoint::from_uniform_bytes(&uniform_bytes)
    }

    /// Hash a slice of bytes into a `RistrettoPoint`.
    ///
    /// Takes a type parameter `D`, which is any `Digest` producing 64
    /// bytes of output.
    ///
    /// Convenience wrapper around `from_hash`.
    ///
    /// # Implementation
    ///
    /// Uses the Ristretto-flavoured Elligator 2 map, so that the
    /// discrete log of the output point with respect to any other
    /// point should be unknown.  The map is applied twice and the
    /// results are added, to ensure a uniform distribution.
    ///
    /// # Example
    ///
    /// ```
    /// # extern crate curve25519_dalek;
    /// # use curve25519_dalek::ristretto::RistrettoPoint;
    /// extern crate sha2;
    /// use sha2::Sha512;
    ///
    /// # // Need fn main() here in comment so the doctest compiles
    /// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
    /// # fn main() {
    /// let msg = "To really appreciate architecture, you may even need to commit a murder";
    /// let P = RistrettoPoint::hash_from_bytes::<Sha512>(msg.as_bytes());
    /// # }
    /// ```
    ///
    pub fn hash_from_bytes<D>(input: &[u8]) -> RistrettoPoint
        where D: Digest<OutputSize = U64> + Default
    {
        let mut hash = D::default();
        hash.input(input);
        RistrettoPoint::from_hash(hash)
    }

    /// Construct a `RistrettoPoint` from an existing `Digest` instance.
    ///
    /// Use this instead of `hash_from_bytes` if it is more convenient
    /// to stream data into the `Digest` than to pass a single byte
    /// slice.
    pub fn from_hash<D>(hash: D) -> RistrettoPoint
        where D: Digest<OutputSize = U64> + Default
    {
        // dealing with generic arrays is clumsy, until const generics land
        let output = hash.result();
        let mut output_bytes = [0u8; 64];
        output_bytes.copy_from_slice(&output.as_slice());

        RistrettoPoint::from_uniform_bytes(&output_bytes)
    }

    /// Construct a `RistrettoPoint` from 64 bytes of data.
    ///
    /// If the input bytes are uniformly distributed, the resulting
    /// point will be uniformly distributed over the group, and its
    /// discrete log with respect to other points should be unknown.
    ///
    /// # Implementation
    ///
    /// This function splits the input array into two 32-byte halves,
    /// takes the low 255 bits of each half mod p, applies the
    /// Ristretto-flavored Elligator map to each, and adds the results.
    pub fn from_uniform_bytes(bytes: &[u8; 64]) -> RistrettoPoint {
        let mut r_1_bytes = [0u8; 32];
        r_1_bytes.copy_from_slice(&bytes[0..32]);
        let r_1 = FieldElement::from_bytes(&r_1_bytes);
        let R_1 = RistrettoPoint::elligator_ristretto_flavor(&r_1);

        let mut r_2_bytes = [0u8; 32];
        r_2_bytes.copy_from_slice(&bytes[32..64]);
        let r_2 = FieldElement::from_bytes(&r_2_bytes);
        let R_2 = RistrettoPoint::elligator_ristretto_flavor(&r_2);

        // Applying Elligator twice and adding the results ensures a
        // uniform distribution.
        &R_1 + &R_2
    }
}
739 
740 impl Identity for RistrettoPoint {
identity() -> RistrettoPoint741     fn identity() -> RistrettoPoint {
742         RistrettoPoint(EdwardsPoint::identity())
743     }
744 }
745 
746 impl Default for RistrettoPoint {
default() -> RistrettoPoint747     fn default() -> RistrettoPoint {
748         RistrettoPoint::identity()
749     }
750 }
751 
752 // ------------------------------------------------------------------------
753 // Equality
754 // ------------------------------------------------------------------------
755 
756 impl PartialEq for RistrettoPoint {
eq(&self, other: &RistrettoPoint) -> bool757     fn eq(&self, other: &RistrettoPoint) -> bool {
758         self.ct_eq(other).unwrap_u8() == 1u8
759     }
760 }
761 
762 impl ConstantTimeEq for RistrettoPoint {
763     /// Test equality between two `RistrettoPoint`s.
764     ///
765     /// # Returns
766     ///
767     /// * `Choice(1)` if the two `RistrettoPoint`s are equal;
768     /// * `Choice(0)` otherwise.
ct_eq(&self, other: &RistrettoPoint) -> Choice769     fn ct_eq(&self, other: &RistrettoPoint) -> Choice {
770         let X1Y2 = &self.0.X * &other.0.Y;
771         let Y1X2 = &self.0.Y * &other.0.X;
772         let X1X2 = &self.0.X * &other.0.X;
773         let Y1Y2 = &self.0.Y * &other.0.Y;
774 
775         X1Y2.ct_eq(&Y1X2) | X1X2.ct_eq(&Y1Y2)
776     }
777 }
778 
779 impl Eq for RistrettoPoint {}
780 
781 // ------------------------------------------------------------------------
782 // Arithmetic
783 // ------------------------------------------------------------------------
784 
785 impl<'a, 'b> Add<&'b RistrettoPoint> for &'a RistrettoPoint {
786     type Output = RistrettoPoint;
787 
add(self, other: &'b RistrettoPoint) -> RistrettoPoint788     fn add(self, other: &'b RistrettoPoint) -> RistrettoPoint {
789         RistrettoPoint(&self.0 + &other.0)
790     }
791 }
792 
793 define_add_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint, Output = RistrettoPoint);
794 
795 impl<'b> AddAssign<&'b RistrettoPoint> for RistrettoPoint {
add_assign(&mut self, _rhs: &RistrettoPoint)796     fn add_assign(&mut self, _rhs: &RistrettoPoint) {
797         *self = (self as &RistrettoPoint) + _rhs;
798     }
799 }
800 
801 define_add_assign_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint);
802 
803 impl<'a, 'b> Sub<&'b RistrettoPoint> for &'a RistrettoPoint {
804     type Output = RistrettoPoint;
805 
sub(self, other: &'b RistrettoPoint) -> RistrettoPoint806     fn sub(self, other: &'b RistrettoPoint) -> RistrettoPoint {
807         RistrettoPoint(&self.0 - &other.0)
808     }
809 }
810 
811 define_sub_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint, Output = RistrettoPoint);
812 
813 impl<'b> SubAssign<&'b RistrettoPoint> for RistrettoPoint {
sub_assign(&mut self, _rhs: &RistrettoPoint)814     fn sub_assign(&mut self, _rhs: &RistrettoPoint) {
815         *self = (self as &RistrettoPoint) - _rhs;
816     }
817 }
818 
819 define_sub_assign_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint);
820 
821 impl<T> Sum<T> for RistrettoPoint
822 where
823     T: Borrow<RistrettoPoint>
824 {
sum<I>(iter: I) -> Self where I: Iterator<Item = T>825     fn sum<I>(iter: I) -> Self
826     where
827         I: Iterator<Item = T>
828     {
829         iter.fold(RistrettoPoint::identity(), |acc, item| acc + item.borrow())
830     }
831 }
832 
833 impl<'a> Neg for &'a RistrettoPoint {
834     type Output = RistrettoPoint;
835 
neg(self) -> RistrettoPoint836     fn neg(self) -> RistrettoPoint {
837         RistrettoPoint(-&self.0)
838     }
839 }
840 
841 impl Neg for RistrettoPoint {
842     type Output = RistrettoPoint;
843 
neg(self) -> RistrettoPoint844     fn neg(self) -> RistrettoPoint {
845         -&self
846     }
847 }
848 
849 impl<'b> MulAssign<&'b Scalar> for RistrettoPoint {
mul_assign(&mut self, scalar: &'b Scalar)850     fn mul_assign(&mut self, scalar: &'b Scalar) {
851         let result = (self as &RistrettoPoint) * scalar;
852         *self = result;
853     }
854 }
855 
856 impl<'a, 'b> Mul<&'b Scalar> for &'a RistrettoPoint {
857     type Output = RistrettoPoint;
858     /// Scalar multiplication: compute `scalar * self`.
mul(self, scalar: &'b Scalar) -> RistrettoPoint859     fn mul(self, scalar: &'b Scalar) -> RistrettoPoint {
860         RistrettoPoint(self.0 * scalar)
861     }
862 }
863 
864 impl<'a, 'b> Mul<&'b RistrettoPoint> for &'a Scalar {
865     type Output = RistrettoPoint;
866 
867     /// Scalar multiplication: compute `self * scalar`.
mul(self, point: &'b RistrettoPoint) -> RistrettoPoint868     fn mul(self, point: &'b RistrettoPoint) -> RistrettoPoint {
869         RistrettoPoint(self * point.0)
870     }
871 }
872 
873 define_mul_assign_variants!(LHS = RistrettoPoint, RHS = Scalar);
874 
875 define_mul_variants!(LHS = RistrettoPoint, RHS = Scalar, Output = RistrettoPoint);
876 define_mul_variants!(LHS = Scalar, RHS = RistrettoPoint, Output = RistrettoPoint);
877 
878 // ------------------------------------------------------------------------
879 // Multiscalar Multiplication impls
880 // ------------------------------------------------------------------------
881 
882 // These use iterator combinators to unwrap the underlying points and
883 // forward to the EdwardsPoint implementations.
884 
#[cfg(feature = "alloc")]
impl MultiscalarMul for RistrettoPoint {
    type Point = RistrettoPoint;

    /// Multiscalar multiplication over `RistrettoPoint`s: each point is
    /// unwrapped to its inner `EdwardsPoint` and the computation is delegated
    /// to `EdwardsPoint::multiscalar_mul`.
    fn multiscalar_mul<I, J>(scalars: I, points: J) -> RistrettoPoint
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator,
        J::Item: Borrow<RistrettoPoint>,
    {
        let inner = EdwardsPoint::multiscalar_mul(
            scalars,
            points.into_iter().map(|P| P.borrow().0),
        );
        RistrettoPoint(inner)
    }
}
902 
#[cfg(feature = "alloc")]
impl VartimeMultiscalarMul for RistrettoPoint {
    type Point = RistrettoPoint;

    /// Variable-time multiscalar multiplication over optional points.
    ///
    /// Each `Some(point)` is unwrapped to its inner `EdwardsPoint` and the work
    /// is delegated to `EdwardsPoint::optional_multiscalar_mul`; its `Option`
    /// result is re-wrapped as a `RistrettoPoint`.
    fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<RistrettoPoint>
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator<Item = Option<RistrettoPoint>>,
    {
        // The items are owned `RistrettoPoint`s, so the inner point is moved
        // out directly; no `Borrow` indirection is needed.
        let extended_points = points.into_iter().map(|opt_P| opt_P.map(|P| P.0));

        EdwardsPoint::optional_multiscalar_mul(scalars, extended_points).map(RistrettoPoint)
    }
}
918 
/// Precomputation for variable-time multiscalar multiplication with `RistrettoPoint`s.
///
/// Constructed from a set of static points through the
/// `VartimePrecomputedMultiscalarMul` trait; the points are stored as their
/// inner `EdwardsPoint`s in a precomputed Straus table.
// This wraps the inner implementation in a facade type so that we can
// decouple stability of the inner type from the stability of the
// outer type.
#[cfg(feature = "alloc")]
pub struct VartimeRistrettoPrecomputation(scalar_mul::precomputed_straus::VartimePrecomputedStraus);
925 
#[cfg(feature = "alloc")]
impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation {
    type Point = RistrettoPoint;

    /// Builds the precomputation from `static_points`, which are unwrapped to
    /// their inner `EdwardsPoint`s before being handed to the Straus table.
    fn new<I>(static_points: I) -> Self
    where
        I: IntoIterator,
        I::Item: Borrow<Self::Point>,
    {
        let inner = scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(
            static_points.into_iter().map(|P| P.borrow().0),
        );
        Self(inner)
    }

    /// Computes the mixed (static + dynamic) multiscalar product by delegating
    /// to the inner Edwards precomputation; dynamic points are unwrapped to
    /// their `EdwardsPoint`s and the `Option` result is re-wrapped.
    fn optional_mixed_multiscalar_mul<I, J, K>(
        &self,
        static_scalars: I,
        dynamic_scalars: J,
        dynamic_points: K,
    ) -> Option<Self::Point>
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator,
        J::Item: Borrow<Scalar>,
        K: IntoIterator<Item = Option<Self::Point>>,
    {
        let dynamic_edwards = dynamic_points
            .into_iter()
            .map(|maybe_point| maybe_point.map(|point| point.0));

        let result = self.0.optional_mixed_multiscalar_mul(
            static_scalars,
            dynamic_scalars,
            dynamic_edwards,
        );

        result.map(RistrettoPoint)
    }
}
964 
965 impl RistrettoPoint {
966     /// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the
967     /// Ristretto basepoint.
vartime_double_scalar_mul_basepoint( a: &Scalar, A: &RistrettoPoint, b: &Scalar, ) -> RistrettoPoint968     pub fn vartime_double_scalar_mul_basepoint(
969         a: &Scalar,
970         A: &RistrettoPoint,
971         b: &Scalar,
972     ) -> RistrettoPoint {
973         RistrettoPoint(
974             EdwardsPoint::vartime_double_scalar_mul_basepoint(a, &A.0, b)
975         )
976     }
977 }
978 
/// A precomputed table of multiples of a basepoint, used to accelerate
/// scalar multiplication.
///
/// Multiplication by a `Scalar` is available with either operand order
/// (`&table * &scalar` or `&scalar * &table`).  Tables for other basepoints
/// can be built with `RistrettoBasepointTable::create`.
///
/// A precomputed table of multiples of the Ristretto basepoint is
/// available in the `constants` module:
/// ```
/// use curve25519_dalek::constants;
/// use curve25519_dalek::scalar::Scalar;
///
/// let a = Scalar::from(87329482u64);
/// let P = &a * &constants::RISTRETTO_BASEPOINT_TABLE;
/// ```
// Wraps an `EdwardsBasepointTable`; all table operations are performed on the
// underlying Edwards points and re-wrapped as `RistrettoPoint`s.
#[derive(Clone)]
pub struct RistrettoBasepointTable(pub(crate) EdwardsBasepointTable);
993 
994 impl<'a, 'b> Mul<&'b Scalar> for &'a RistrettoBasepointTable {
995     type Output = RistrettoPoint;
996 
mul(self, scalar: &'b Scalar) -> RistrettoPoint997     fn mul(self, scalar: &'b Scalar) -> RistrettoPoint {
998         RistrettoPoint(&self.0 * scalar)
999     }
1000 }
1001 
1002 impl<'a, 'b> Mul<&'a RistrettoBasepointTable> for &'b Scalar {
1003     type Output = RistrettoPoint;
1004 
mul(self, basepoint_table: &'a RistrettoBasepointTable) -> RistrettoPoint1005     fn mul(self, basepoint_table: &'a RistrettoBasepointTable) -> RistrettoPoint {
1006         RistrettoPoint(self * &basepoint_table.0)
1007     }
1008 }
1009 
1010 impl RistrettoBasepointTable {
1011     /// Create a precomputed table of multiples of the given `basepoint`.
create(basepoint: &RistrettoPoint) -> RistrettoBasepointTable1012     pub fn create(basepoint: &RistrettoPoint) -> RistrettoBasepointTable {
1013         RistrettoBasepointTable(EdwardsBasepointTable::create(&basepoint.0))
1014     }
1015 
1016     /// Get the basepoint for this table as a `RistrettoPoint`.
basepoint(&self) -> RistrettoPoint1017     pub fn basepoint(&self) -> RistrettoPoint {
1018         RistrettoPoint(self.0.basepoint())
1019     }
1020 }
1021 
1022 // ------------------------------------------------------------------------
1023 // Constant-time conditional selection
1024 // ------------------------------------------------------------------------
1025 
1026 impl ConditionallySelectable for RistrettoPoint {
1027     /// Conditionally select between `self` and `other`.
1028     ///
1029     /// # Example
1030     ///
1031     /// ```
1032     /// # extern crate subtle;
1033     /// # extern crate curve25519_dalek;
1034     /// #
1035     /// use subtle::ConditionallySelectable;
1036     /// use subtle::Choice;
1037     /// #
1038     /// # use curve25519_dalek::traits::Identity;
1039     /// # use curve25519_dalek::ristretto::RistrettoPoint;
1040     /// # use curve25519_dalek::constants;
1041     /// # fn main() {
1042     ///
1043     /// let A = RistrettoPoint::identity();
1044     /// let B = constants::RISTRETTO_BASEPOINT_POINT;
1045     ///
1046     /// let mut P = A;
1047     ///
1048     /// P = RistrettoPoint::conditional_select(&A, &B, Choice::from(0));
1049     /// assert_eq!(P, A);
1050     /// P = RistrettoPoint::conditional_select(&A, &B, Choice::from(1));
1051     /// assert_eq!(P, B);
1052     /// # }
1053     /// ```
conditional_select( a: &RistrettoPoint, b: &RistrettoPoint, choice: Choice, ) -> RistrettoPoint1054     fn conditional_select(
1055         a: &RistrettoPoint,
1056         b: &RistrettoPoint,
1057         choice: Choice,
1058     ) -> RistrettoPoint {
1059         RistrettoPoint(EdwardsPoint::conditional_select(&a.0, &b.0, choice))
1060     }
1061 }
1062 
1063 // ------------------------------------------------------------------------
1064 // Debug traits
1065 // ------------------------------------------------------------------------
1066 
1067 impl Debug for CompressedRistretto {
fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result1068     fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
1069         write!(f, "CompressedRistretto: {:?}", self.as_bytes())
1070     }
1071 }
1072 
1073 impl Debug for RistrettoPoint {
fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result1074     fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
1075         let coset = self.coset4();
1076         write!(f, "RistrettoPoint: coset \n{:?}\n{:?}\n{:?}\n{:?}",
1077                coset[0], coset[1], coset[2], coset[3])
1078     }
1079 }
1080 
1081 // ------------------------------------------------------------------------
1082 // Tests
1083 // ------------------------------------------------------------------------
1084 
#[cfg(test)]
mod test {
    use rand_core::OsRng;

    use scalar::Scalar;
    use constants;
    use edwards::CompressedEdwardsY;
    use traits::Identity;
    use super::*;

    /// Serializing a point and its compression with bincode must give the same
    /// 32 bytes, and both must round-trip.
    #[test]
    #[cfg(feature = "serde")]
    fn serde_bincode_basepoint_roundtrip() {
        use bincode;

        let encoded = bincode::serialize(&constants::RISTRETTO_BASEPOINT_POINT).unwrap();
        let enc_compressed = bincode::serialize(&constants::RISTRETTO_BASEPOINT_COMPRESSED).unwrap();
        assert_eq!(encoded, enc_compressed);

        // Check that the encoding is 32 bytes exactly
        assert_eq!(encoded.len(), 32);

        let dec_uncompressed: RistrettoPoint = bincode::deserialize(&encoded).unwrap();
        let dec_compressed: CompressedRistretto = bincode::deserialize(&encoded).unwrap();

        assert_eq!(dec_uncompressed, constants::RISTRETTO_BASEPOINT_POINT);
        assert_eq!(dec_compressed, constants::RISTRETTO_BASEPOINT_COMPRESSED);

        // Check that the encoding itself matches the usual one
        let raw_bytes = constants::RISTRETTO_BASEPOINT_COMPRESSED.as_bytes();
        let bp: RistrettoPoint = bincode::deserialize(raw_bytes).unwrap();
        assert_eq!(bp, constants::RISTRETTO_BASEPOINT_POINT);
    }

    /// `point * scalar` and `scalar * point` must agree.
    #[test]
    fn scalarmult_ristrettopoint_works_both_ways() {
        let P = constants::RISTRETTO_BASEPOINT_POINT;
        let s = Scalar::from(999u64);

        let P1 = &P * &s;
        let P2 = &s * &P;

        // `assert_eq!` prints both encodings on failure.
        assert_eq!(P1.compress().as_bytes(), P2.compress().as_bytes());
    }

    /// `Sum` must work for non-empty, empty, and owning iterators.
    #[test]
    fn impl_sum() {

        // Test that sum works for non-empty iterators
        let BASE = constants::RISTRETTO_BASEPOINT_POINT;

        let s1 = Scalar::from(999u64);
        let P1 = &BASE * &s1;

        let s2 = Scalar::from(333u64);
        let P2 = &BASE * &s2;

        // `RistrettoPoint` is `Copy`, so no clones are needed here.
        let vec = vec![P1, P2];
        let sum: RistrettoPoint = vec.iter().sum();

        assert_eq!(sum, P1 + P2);

        // Test that sum works for the empty iterator
        let empty_vector: Vec<RistrettoPoint> = vec![];
        let sum: RistrettoPoint = empty_vector.iter().sum();

        assert_eq!(sum, RistrettoPoint::identity());

        // Test that sum works on owning iterators
        let s = Scalar::from(2u64);
        let mapped = vec.iter().map(|x| x * s);
        let sum: RistrettoPoint = mapped.sum();

        assert_eq!(sum, &P1 * &s + &P2 * &s);
    }

    #[test]
    fn decompress_negative_s_fails() {
        // constants::d is neg, so decompression should fail as |d| != d.
        let bad_compressed = CompressedRistretto(constants::EDWARDS_D.to_bytes());
        assert!(bad_compressed.decompress().is_none());
    }

    /// Decompressing the identity encoding must yield a point whose coset
    /// contains the Edwards identity.
    #[test]
    fn decompress_id() {
        let compressed_id = CompressedRistretto::identity();
        let id = compressed_id.decompress().unwrap();
        let identity_in_coset = id
            .coset4()
            .iter()
            .any(|P| P.compress() == CompressedEdwardsY::identity());
        assert!(identity_in_coset);
    }

    #[test]
    fn compress_id() {
        let id = RistrettoPoint::identity();
        assert_eq!(id.compress(), CompressedRistretto::identity());
    }

    #[test]
    fn basepoint_roundtrip() {
        let bp_compressed_ristretto = constants::RISTRETTO_BASEPOINT_POINT.compress();
        let bp_recaf = bp_compressed_ristretto.decompress().unwrap().0;
        // Check that bp_recaf differs from bp by a point of order 4
        let diff = &constants::RISTRETTO_BASEPOINT_POINT.0 - &bp_recaf;
        let diff4 = diff.mul_by_pow_2(2);
        assert_eq!(diff4.compress(), CompressedEdwardsY::identity());
    }

    /// The encodings of `i * basepoint` for `i` in `0..16` must match the
    /// reference table.
    #[test]
    fn encodings_of_small_multiples_of_basepoint() {
        // Table of encodings of i*basepoint
        // Generated using ristretto.sage
        let compressed = [
            CompressedRistretto([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            CompressedRistretto([226, 242, 174, 10, 106, 188, 78, 113, 168, 132, 169, 97, 197, 0, 81, 95, 88, 227, 11, 106, 165, 130, 221, 141, 182, 166, 89, 69, 224, 141, 45, 118]),
            CompressedRistretto([106, 73, 50, 16, 247, 73, 156, 209, 127, 236, 181, 16, 174, 12, 234, 35, 161, 16, 232, 213, 185, 1, 248, 172, 173, 211, 9, 92, 115, 163, 185, 25]),
            CompressedRistretto([148, 116, 31, 93, 93, 82, 117, 94, 206, 79, 35, 240, 68, 238, 39, 213, 209, 234, 30, 43, 209, 150, 180, 98, 22, 107, 22, 21, 42, 157, 2, 89]),
            CompressedRistretto([218, 128, 134, 39, 115, 53, 139, 70, 111, 250, 223, 224, 179, 41, 58, 179, 217, 253, 83, 197, 234, 108, 149, 83, 88, 245, 104, 50, 45, 175, 106, 87]),
            CompressedRistretto([232, 130, 177, 49, 1, 107, 82, 193, 211, 51, 112, 128, 24, 124, 247, 104, 66, 62, 252, 203, 181, 23, 187, 73, 90, 184, 18, 196, 22, 15, 244, 78]),
            CompressedRistretto([246, 71, 70, 211, 201, 43, 19, 5, 14, 216, 216, 2, 54, 167, 240, 0, 124, 59, 63, 150, 47, 91, 167, 147, 209, 154, 96, 30, 187, 29, 244, 3]),
            CompressedRistretto([68, 245, 53, 32, 146, 110, 200, 31, 189, 90, 56, 120, 69, 190, 183, 223, 133, 169, 106, 36, 236, 225, 135, 56, 189, 207, 166, 167, 130, 42, 23, 109]),
            CompressedRistretto([144, 50, 147, 216, 242, 40, 126, 190, 16, 226, 55, 77, 193, 165, 62, 11, 200, 135, 229, 146, 105, 159, 2, 208, 119, 213, 38, 60, 221, 85, 96, 28]),
            CompressedRistretto([2, 98, 42, 206, 143, 115, 3, 163, 28, 175, 198, 63, 143, 196, 143, 220, 22, 225, 200, 200, 210, 52, 178, 240, 214, 104, 82, 130, 169, 7, 96, 49]),
            CompressedRistretto([32, 112, 111, 215, 136, 178, 114, 10, 30, 210, 165, 218, 212, 149, 43, 1, 244, 19, 188, 240, 231, 86, 77, 232, 205, 200, 22, 104, 158, 45, 185, 95]),
            CompressedRistretto([188, 232, 63, 139, 165, 221, 47, 165, 114, 134, 76, 36, 186, 24, 16, 249, 82, 43, 198, 0, 74, 254, 149, 135, 122, 199, 50, 65, 202, 253, 171, 66]),
            CompressedRistretto([228, 84, 158, 225, 107, 154, 160, 48, 153, 202, 32, 140, 103, 173, 175, 202, 250, 76, 63, 62, 78, 83, 3, 222, 96, 38, 227, 202, 143, 248, 68, 96]),
            CompressedRistretto([170, 82, 224, 0, 223, 46, 22, 245, 95, 177, 3, 47, 195, 59, 196, 39, 66, 218, 214, 189, 90, 143, 192, 190, 1, 103, 67, 108, 89, 72, 80, 31]),
            CompressedRistretto([70, 55, 107, 128, 244, 9, 178, 157, 194, 181, 246, 240, 197, 37, 145, 153, 8, 150, 229, 113, 111, 65, 71, 124, 211, 0, 133, 171, 127, 16, 48, 30]),
            CompressedRistretto([224, 196, 24, 247, 200, 217, 196, 205, 215, 57, 91, 147, 234, 18, 79, 58, 217, 144, 33, 187, 104, 29, 252, 51, 2, 169, 217, 154, 46, 83, 230, 78]),
        ];
        let mut bp = RistrettoPoint::identity();
        for expected in compressed.iter() {
            assert_eq!(&bp.compress(), expected);
            bp = &bp + &constants::RISTRETTO_BASEPOINT_POINT;
        }
    }

    /// Every Edwards representative in the basepoint's coset must compare
    /// equal to the basepoint as a `RistrettoPoint`.
    #[test]
    fn four_torsion_basepoint() {
        let bp = constants::RISTRETTO_BASEPOINT_POINT;
        let bp_coset = bp.coset4();
        for rep in bp_coset.iter() {
            assert_eq!(bp, RistrettoPoint(*rep));
        }
    }

    /// Same as `four_torsion_basepoint`, for a random multiple of the basepoint.
    #[test]
    fn four_torsion_random() {
        let mut rng = OsRng;
        let B = &constants::RISTRETTO_BASEPOINT_TABLE;
        let P = B * &Scalar::random(&mut rng);
        let P_coset = P.coset4();
        for rep in P_coset.iter() {
            assert_eq!(P, RistrettoPoint(*rep));
        }
    }

    #[test]
    fn elligator_vs_ristretto_sage() {
        // Test vectors extracted from ristretto.sage.
        //
        // Notice that all of the byte sequences have bit 255 set to 0; this is because
        // ristretto.sage does not mask the high bit of a field element.  When the high bit is set,
        // the ristretto.sage elligator implementation gives different results, since it takes a
        // different field element as input.
        let bytes: [[u8;32]; 16] = [
            [184, 249, 135, 49, 253, 123, 89, 113, 67, 160, 6, 239, 7, 105, 211, 41, 192, 249, 185, 57, 9, 102, 70, 198, 15, 127, 7, 26, 160, 102, 134, 71],
            [229, 14, 241, 227, 75, 9, 118, 60, 128, 153, 226, 21, 183, 217, 91, 136, 98, 0, 231, 156, 124, 77, 82, 139, 142, 134, 164, 169, 169, 62, 250, 52],
            [115, 109, 36, 220, 180, 223, 99, 6, 204, 169, 19, 29, 169, 68, 84, 23, 21, 109, 189, 149, 127, 205, 91, 102, 172, 35, 112, 35, 134, 69, 186, 34],
            [16, 49, 96, 107, 171, 199, 164, 9, 129, 16, 64, 62, 241, 63, 132, 173, 209, 160, 112, 215, 105, 50, 157, 81, 253, 105, 1, 154, 229, 25, 120, 83],
            [156, 131, 161, 162, 236, 251, 5, 187, 167, 171, 17, 178, 148, 210, 90, 207, 86, 21, 79, 161, 167, 215, 234, 1, 136, 242, 182, 248, 38, 85, 79, 86],
            [251, 177, 124, 54, 18, 101, 75, 235, 245, 186, 19, 46, 133, 157, 229, 64, 10, 136, 181, 185, 78, 144, 254, 167, 137, 49, 107, 10, 61, 10, 21, 25],
            [232, 193, 20, 68, 240, 77, 186, 77, 183, 40, 44, 86, 150, 31, 198, 212, 76, 81, 3, 217, 197, 8, 126, 128, 126, 152, 164, 208, 153, 44, 189, 77],
            [173, 229, 149, 177, 37, 230, 30, 69, 61, 56, 172, 190, 219, 115, 167, 194, 71, 134, 59, 75, 28, 244, 118, 26, 162, 97, 64, 16, 15, 189, 30, 64],
            [106, 71, 61, 107, 250, 117, 42, 151, 91, 202, 212, 100, 52, 188, 190, 21, 125, 218, 31, 18, 253, 241, 160, 133, 57, 242, 3, 164, 189, 68, 111, 75],
            [112, 204, 182, 90, 220, 198, 120, 73, 173, 107, 193, 17, 227, 40, 162, 36, 150, 141, 235, 55, 172, 183, 12, 39, 194, 136, 43, 153, 244, 118, 91, 89],
            [111, 24, 203, 123, 254, 189, 11, 162, 51, 196, 163, 136, 204, 143, 10, 222, 33, 112, 81, 205, 34, 35, 8, 66, 90, 6, 164, 58, 170, 177, 34, 25],
            [225, 183, 30, 52, 236, 82, 6, 183, 109, 25, 227, 181, 25, 82, 41, 193, 80, 77, 161, 80, 242, 203, 79, 204, 136, 245, 131, 110, 237, 106, 3, 58],
            [207, 246, 38, 56, 30, 86, 176, 90, 27, 200, 61, 42, 221, 27, 56, 210, 79, 178, 189, 120, 68, 193, 120, 167, 77, 185, 53, 197, 124, 128, 191, 126],
            [1, 136, 215, 80, 240, 46, 63, 147, 16, 244, 230, 207, 82, 189, 74, 50, 106, 169, 138, 86, 30, 131, 214, 202, 166, 125, 251, 228, 98, 24, 36, 21],
            [210, 207, 228, 56, 155, 116, 207, 54, 84, 195, 251, 215, 249, 199, 116, 75, 109, 239, 196, 251, 194, 246, 252, 228, 70, 146, 156, 35, 25, 39, 241, 4],
            [34, 116, 123, 9, 8, 40, 93, 189, 9, 103, 57, 103, 66, 227, 3, 2, 157, 107, 134, 219, 202, 74, 230, 154, 78, 107, 219, 195, 214, 14, 84, 80],
        ];
        let encoded_images: [CompressedRistretto; 16] = [
            CompressedRistretto([176, 157, 237, 97, 66, 29, 140, 166, 168, 94, 26, 157, 212, 216, 229, 160, 195, 246, 232, 239, 169, 112, 63, 193, 64, 32, 152, 69, 11, 190, 246, 86]),
            CompressedRistretto([234, 141, 77, 203, 181, 225, 250, 74, 171, 62, 15, 118, 78, 212, 150, 19, 131, 14, 188, 238, 194, 244, 141, 138, 166, 162, 83, 122, 228, 201, 19, 26]),
            CompressedRistretto([232, 231, 51, 92, 5, 168, 80, 36, 173, 179, 104, 68, 186, 149, 68, 40, 140, 170, 27, 103, 99, 140, 21, 242, 43, 62, 250, 134, 208, 255, 61, 89]),
            CompressedRistretto([208, 120, 140, 129, 177, 179, 237, 159, 252, 160, 28, 13, 206, 5, 211, 241, 192, 218, 1, 97, 130, 241, 20, 169, 119, 46, 246, 29, 79, 80, 77, 84]),
            CompressedRistretto([202, 11, 236, 145, 58, 12, 181, 157, 209, 6, 213, 88, 75, 147, 11, 119, 191, 139, 47, 142, 33, 36, 153, 193, 223, 183, 178, 8, 205, 120, 248, 110]),
            CompressedRistretto([26, 66, 231, 67, 203, 175, 116, 130, 32, 136, 62, 253, 215, 46, 5, 214, 166, 248, 108, 237, 216, 71, 244, 173, 72, 133, 82, 6, 143, 240, 104, 41]),
            CompressedRistretto([40, 157, 102, 96, 201, 223, 200, 197, 150, 181, 106, 83, 103, 126, 143, 33, 145, 230, 78, 6, 171, 146, 210, 143, 112, 5, 245, 23, 183, 138, 18, 120]),
            CompressedRistretto([220, 37, 27, 203, 239, 196, 176, 131, 37, 66, 188, 243, 185, 250, 113, 23, 167, 211, 154, 243, 168, 215, 54, 171, 159, 36, 195, 81, 13, 150, 43, 43]),
            CompressedRistretto([232, 121, 176, 222, 183, 196, 159, 90, 238, 193, 105, 52, 101, 167, 244, 170, 121, 114, 196, 6, 67, 152, 80, 185, 221, 7, 83, 105, 176, 208, 224, 121]),
            CompressedRistretto([226, 181, 183, 52, 241, 163, 61, 179, 221, 207, 220, 73, 245, 242, 25, 236, 67, 84, 179, 222, 167, 62, 167, 182, 32, 9, 92, 30, 165, 127, 204, 68]),
            CompressedRistretto([226, 119, 16, 242, 200, 139, 240, 87, 11, 222, 92, 146, 156, 243, 46, 119, 65, 59, 1, 248, 92, 183, 50, 175, 87, 40, 206, 53, 208, 220, 148, 13]),
            CompressedRistretto([70, 240, 79, 112, 54, 157, 228, 146, 74, 122, 216, 88, 232, 62, 158, 13, 14, 146, 115, 117, 176, 222, 90, 225, 244, 23, 94, 190, 150, 7, 136, 96]),
            CompressedRistretto([22, 71, 241, 103, 45, 193, 195, 144, 183, 101, 154, 50, 39, 68, 49, 110, 51, 44, 62, 0, 229, 113, 72, 81, 168, 29, 73, 106, 102, 40, 132, 24]),
            CompressedRistretto([196, 133, 107, 11, 130, 105, 74, 33, 204, 171, 133, 221, 174, 193, 241, 36, 38, 179, 196, 107, 219, 185, 181, 253, 228, 47, 155, 42, 231, 73, 41, 78]),
            CompressedRistretto([58, 255, 225, 197, 115, 208, 160, 143, 39, 197, 82, 69, 143, 235, 92, 170, 74, 40, 57, 11, 171, 227, 26, 185, 217, 207, 90, 185, 197, 190, 35, 60]),
            CompressedRistretto([88, 43, 92, 118, 223, 136, 105, 145, 238, 186, 115, 8, 214, 112, 153, 253, 38, 108, 205, 230, 157, 130, 11, 66, 101, 85, 253, 110, 110, 14, 148, 112]),
        ];
        for (input, expected) in bytes.iter().zip(encoded_images.iter()) {
            let r_0 = FieldElement::from_bytes(input);
            let Q = RistrettoPoint::elligator_ristretto_flavor(&r_0);
            assert_eq!(&Q.compress(), expected);
        }
    }

    /// Compressing and then decompressing random points must be the identity map.
    #[test]
    fn random_roundtrip() {
        let mut rng = OsRng;
        let B = &constants::RISTRETTO_BASEPOINT_TABLE;
        for _ in 0..100 {
            let P = B * &Scalar::random(&mut rng);
            let compressed_P = P.compress();
            let Q = compressed_P.decompress().unwrap();
            assert_eq!(P, Q);
        }
    }

    /// Batch double-and-compress must agree with doubling and compressing each
    /// point individually.
    #[test]
    fn double_and_compress_1024_random_points() {
        let mut rng = OsRng;

        let points: Vec<RistrettoPoint> =
            (0..1024).map(|_| RistrettoPoint::random(&mut rng)).collect();

        let compressed = RistrettoPoint::double_and_compress_batch(&points);

        for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
            assert_eq!(*P2_compressed, (P + P).compress());
        }
    }

    /// The precomputed mixed multiscalar multiplication must agree with the
    /// non-precomputed one and with a direct basepoint multiplication.
    // `VartimeRistrettoPrecomputation` and the `VartimeMultiscalarMul` impl are
    // only compiled with the "alloc" feature, so this test is gated the same way.
    #[test]
    #[cfg(feature = "alloc")]
    fn vartime_precomputed_vs_nonprecomputed_multiscalar() {
        let mut rng = OsRng;

        let B = &constants::RISTRETTO_BASEPOINT_TABLE;

        let static_scalars = (0..128)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        let dynamic_scalars = (0..128)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        // Every point below is s * B, so the total is (sum of s^2) * B.
        let check_scalar: Scalar = static_scalars
            .iter()
            .chain(dynamic_scalars.iter())
            .map(|s| s * s)
            .sum();

        let static_points = static_scalars.iter().map(|s| s * B).collect::<Vec<_>>();
        let dynamic_points = dynamic_scalars.iter().map(|s| s * B).collect::<Vec<_>>();

        let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());

        let P = precomputation.vartime_mixed_multiscalar_mul(
            &static_scalars,
            &dynamic_scalars,
            &dynamic_points,
        );

        use traits::VartimeMultiscalarMul;
        let Q = RistrettoPoint::vartime_multiscalar_mul(
            static_scalars.iter().chain(dynamic_scalars.iter()),
            static_points.iter().chain(dynamic_points.iter()),
        );

        let R = &check_scalar * B;

        assert_eq!(P.compress(), R.compress());
        assert_eq!(Q.compress(), R.compress());
    }
}
1366