1 // -*- mode: rust; -*-
2 //
3 // This file is part of curve25519-dalek.
4 // Copyright (c) 2016-2019 Isis Lovecruft, Henry de Valence
5 // See LICENSE for licensing information.
6 //
7 // Authors:
8 // - Isis Agora Lovecruft <isis@patternsinthevoid.net>
9 // - Henry de Valence <hdevalence@hdevalence.ca>
10 
11 //! Parallel Edwards Arithmetic for Curve25519.
12 //!
13 //! This module currently has two point types:
14 //!
15 //! * `ExtendedPoint`: a point stored in vector-friendly format, with
16 //! vectorized doubling and addition;
17 //!
//! * `CachedPoint`: a point with precomputed values, used for readdition.
19 //!
20 //! Details on the formulas can be found in the documentation for the
21 //! parent `avx2` module.
22 //!
23 //! This API is designed to be safe: vectorized points can only be
24 //! created from serial points (which do validation on decompression),
25 //! and operations on valid points return valid points, so invalid
26 //! point states should be unrepresentable.
27 //!
//! This design goal is met, with one exception: the `Neg`
//! implementation for the `CachedPoint` performs a lazy negation, so
//! that subtraction can be efficiently implemented as a negation and
//! an addition.  Repeatedly negating a `CachedPoint` will cause its
//! coefficients to grow and eventually overflow.  Repeatedly negating
//! a point should not be necessary anyway.
34 
35 #![allow(non_snake_case)]
36 
37 use core::convert::From;
38 use core::ops::{Add, Neg, Sub};
39 
40 use subtle::Choice;
41 use subtle::ConditionallySelectable;
42 
43 use edwards;
44 use window::{LookupTable, NafLookupTable5, NafLookupTable8};
45 
46 use traits::Identity;
47 
48 use super::constants;
49 use super::field::{FieldElement2625x4, Lanes, Shuffle};
50 
/// A point on Curve25519, using parallel Edwards formulas for curve
/// operations.
///
/// The wrapped `FieldElement2625x4` holds the four coordinates
/// \\( (X, Y, Z, T) \\) of the corresponding `edwards::EdwardsPoint`,
/// one coordinate per vector lane, in that order.
///
/// # Invariant
///
/// The coefficients of an `ExtendedPoint` are bounded with
/// \\( b < 0.007 \\).
#[derive(Copy, Clone, Debug)]
pub struct ExtendedPoint(pub(super) FieldElement2625x4);
60 
61 impl From<edwards::EdwardsPoint> for ExtendedPoint {
from(P: edwards::EdwardsPoint) -> ExtendedPoint62     fn from(P: edwards::EdwardsPoint) -> ExtendedPoint {
63         ExtendedPoint(FieldElement2625x4::new(&P.X, &P.Y, &P.Z, &P.T))
64     }
65 }
66 
67 impl From<ExtendedPoint> for edwards::EdwardsPoint {
from(P: ExtendedPoint) -> edwards::EdwardsPoint68     fn from(P: ExtendedPoint) -> edwards::EdwardsPoint {
69         let tmp = P.0.split();
70         edwards::EdwardsPoint {
71             X: tmp[0],
72             Y: tmp[1],
73             Z: tmp[2],
74             T: tmp[3],
75         }
76     }
77 }
78 
79 impl ConditionallySelectable for ExtendedPoint {
conditional_select(a: &Self, b: &Self, choice: Choice) -> Self80     fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
81         ExtendedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice))
82     }
83 
conditional_assign(&mut self, other: &Self, choice: Choice)84     fn conditional_assign(&mut self, other: &Self, choice: Choice) {
85         self.0.conditional_assign(&other.0, choice);
86     }
87 }
88 
89 impl Default for ExtendedPoint {
default() -> ExtendedPoint90     fn default() -> ExtendedPoint {
91         ExtendedPoint::identity()
92     }
93 }
94 
95 impl Identity for ExtendedPoint {
identity() -> ExtendedPoint96     fn identity() -> ExtendedPoint {
97         constants::EXTENDEDPOINT_IDENTITY
98     }
99 }
100 
impl ExtendedPoint {
    /// Compute the double of this point, \\( [2]P \\).
    ///
    /// The four lanes \\( (X_1, Y_1, Z_1, X_1 + Y_1) \\) are squared at
    /// once and combined into \\( (S_5, S_6, S_8, S_9) \\). The result is
    /// then finished with a single lane-wise multiplication
    /// \\( (S_8 S_9, S_5 S_6, S_8 S_6, S_5 S_9) = (X_3, Y_3, Z_3, T_3) \\).
    pub fn double(&self) -> ExtendedPoint {
        // Want to compute (X1 Y1 Z1 X1+Y1).
        // Not sure how to do this less expensively than computing
        // (X1 Y1 Z1 T1) --(256bit shuffle)--> (X1 Y1 X1 Y1)
        // (X1 Y1 X1 Y1) --(2x128b shuffle)--> (Y1 X1 Y1 X1)
        // and then adding.

        // Set tmp0 = (X1 Y1 X1 Y1)
        let mut tmp0 = self.0.shuffle(Shuffle::ABAB);

        // Set tmp1 = (Y1 X1 Y1 X1)
        let mut tmp1 = tmp0.shuffle(Shuffle::BADC);

        // Set tmp0 = (X1 Y1 Z1 X1+Y1)
        // (T1 is not needed for doubling, so lane D is overwritten.)
        tmp0 = self.0.blend(tmp0 + tmp1, Lanes::D);

        // Set tmp1 = tmp0^2, negating the D values
        tmp1 = tmp0.square_and_negate_D();
        // Now tmp1 = (S1 S2 S3 -S4) with b < 0.007

        // See discussion of bounds in the module-level documentation.
        // We want to compute
        //
        //    + | S1 | S1 | S1 | S1 |
        //    + | S2 |    |    | S2 |
        //    + |    |    | S3 |    |
        //    + |    |    | S3 |    |
        //    + |    |    |    |-S4 |
        //    + |    | 2p | 2p |    |
        //    - |    | S2 | S2 |    |
        //    =======================
        //        S5   S6   S8   S9
        //
        // The 2p row comes from the lazy negation of S2 below, which
        // keeps lanes B and C non-negative without a reduction.

        let zero = FieldElement2625x4::zero();
        let S_1 = tmp1.shuffle(Shuffle::AAAA);
        let S_2 = tmp1.shuffle(Shuffle::BBBB);

        tmp0 = zero.blend(tmp1 + tmp1, Lanes::C);
        // tmp0 = (0, 0,  2S_3, 0)
        tmp0 = tmp0.blend(tmp1, Lanes::D);
        // tmp0 = (0, 0,  2S_3, -S_4)
        tmp0 = tmp0 + S_1;
        // tmp0 = (  S_1,   S_1, S_1 + 2S_3, S_1 - S_4)
        tmp0 = tmp0 + zero.blend(S_2, Lanes::AD);
        // tmp0 = (S_1 + S_2,   S_1, S_1 + 2S_3, S_1 + S_2 - S_4)
        tmp0 = tmp0 + zero.blend(S_2.negate_lazy(), Lanes::BC);
        // tmp0 = (S_1 + S_2, S_1 - S_2, S_1 - S_2 + 2S_3, S_1 + S_2 - S_4)
        //    b < (     1.01,       1.6,             2.33,             1.6)
        // Now tmp0 = (S_5, S_6, S_8, S_9)

        // Set tmp1 = ( S_9,  S_6,  S_6,  S_9)
        //        b < ( 1.6,  1.6,  1.6,  1.6)
        tmp1 = tmp0.shuffle(Shuffle::DBBD);
        // Set tmp0 = ( S_8,  S_5,  S_8,  S_5)
        //        b < (2.33, 1.01, 2.33, 1.01)
        tmp0 = tmp0.shuffle(Shuffle::CACA);

        // Bounds on (tmp0, tmp1) are (2.33, 1.6) < (2.5, 1.75).
        // The lane-wise product is (S8*S9, S5*S6, S8*S6, S5*S9) = (X3 Y3 Z3 T3).
        ExtendedPoint(&tmp0 * &tmp1)
    }

    /// Compute \\( [2^k] P \\) for this point \\( P \\) by calling
    /// `double` exactly `k` times.
    ///
    /// When `k == 0` this returns a copy of `self`.
    ///
    /// The number of loop iterations, and hence the running time,
    /// depends on `k`. NOTE(review): confirm callers only pass
    /// non-secret values of `k`.
    pub fn mul_by_pow_2(&self, k: u32) -> ExtendedPoint {
        let mut tmp: ExtendedPoint = *self;
        for _ in 0..k {
            tmp = tmp.double();
        }
        tmp
    }
}
172 
/// A cached point with some precomputed variables used for readdition.
///
/// Converting an `ExtendedPoint` \\( (X, Y, Z, T) \\) into a
/// `CachedPoint` stores the lanes
/// \\( (121666 (Y - X), 121666 (Y + X), 2 \cdot 121666 Z, -2 \cdot 121665 T) \\).
/// These values are computed once and then reused by every
/// `&ExtendedPoint + &CachedPoint` addition that takes this point.
///
/// # Warning
///
/// It is not safe to negate this point more than once.
///
/// # Invariant
///
/// As long as the `CachedPoint` is not repeatedly negated, its
/// coefficients will be bounded with \\( b < 1.0 \\).
#[derive(Copy, Clone, Debug)]
pub struct CachedPoint(pub(super) FieldElement2625x4);
185 
186 impl From<ExtendedPoint> for CachedPoint {
from(P: ExtendedPoint) -> CachedPoint187     fn from(P: ExtendedPoint) -> CachedPoint {
188         let mut x = P.0;
189 
190         x = x.blend(x.diff_sum(), Lanes::AB);
191         // x = (Y2 - X2, Y2 + X2, Z2, T2) = (S2 S3 Z2 T2)
192 
193         x = x * (121666, 121666, 2 * 121666, 2 * 121665);
194         // x = (121666*S2 121666*S3 2*121666*Z2 2*121665*T2)
195 
196         x = x.blend(-x, Lanes::D);
197         // x = (121666*S2 121666*S3 2*121666*Z2 -2*121665*T2)
198 
199         // The coefficients of the output are bounded with b < 0.007.
200         CachedPoint(x)
201     }
202 }
203 
204 impl Default for CachedPoint {
default() -> CachedPoint205     fn default() -> CachedPoint {
206         CachedPoint::identity()
207     }
208 }
209 
210 impl Identity for CachedPoint {
identity() -> CachedPoint211     fn identity() -> CachedPoint {
212         constants::CACHEDPOINT_IDENTITY
213     }
214 }
215 
216 impl ConditionallySelectable for CachedPoint {
conditional_select(a: &Self, b: &Self, choice: Choice) -> Self217     fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
218         CachedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice))
219     }
220 
conditional_assign(&mut self, other: &Self, choice: Choice)221     fn conditional_assign(&mut self, other: &Self, choice: Choice) {
222         self.0.conditional_assign(&other.0, choice);
223     }
224 }
225 
226 impl<'a> Neg for &'a CachedPoint {
227     type Output = CachedPoint;
228     /// Lazily negate the point.
229     ///
230     /// # Warning
231     ///
232     /// Because this method does not perform a reduction, it is not
233     /// safe to repeatedly negate a point.
neg(self) -> CachedPoint234     fn neg(self) -> CachedPoint {
235         let swapped = self.0.shuffle(Shuffle::BACD);
236         CachedPoint(swapped.blend(swapped.negate_lazy(), Lanes::D))
237     }
238 }
239 
240 impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint {
241     type Output = ExtendedPoint;
242 
243     /// Add an `ExtendedPoint` and a `CachedPoint`.
add(self, other: &'b CachedPoint) -> ExtendedPoint244     fn add(self, other: &'b CachedPoint) -> ExtendedPoint {
245         // The coefficients of an `ExtendedPoint` are reduced after
246         // every operation.  If the `CachedPoint` was negated, its
247         // coefficients grow by one bit.  So on input, `self` is
248         // bounded with `b < 0.007` and `other` is bounded with
249         // `b < 1.0`.
250 
251         let mut tmp = self.0;
252 
253         tmp = tmp.blend(tmp.diff_sum(), Lanes::AB);
254         // tmp = (Y1-X1 Y1+X1 Z1 T1) = (S0 S1 Z1 T1) with b < 1.6
255 
256         // (tmp, other) bounded with b < (1.6, 1.0) < (2.5, 1.75).
257         tmp = &tmp * &other.0;
258         // tmp = (S0*S2' S1*S3' Z1*Z2' T1*T2') = (S8 S9 S10 S11)
259 
260         tmp = tmp.shuffle(Shuffle::ABDC);
261         // tmp = (S8 S9 S11 S10)
262 
263         tmp = tmp.diff_sum();
264         // tmp = (S9-S8 S9+S8 S10-S11 S10+S11) = (S12 S13 S14 S15)
265 
266         let t0 = tmp.shuffle(Shuffle::ADDA);
267         // t0 = (S12 S15 S15 S12)
268         let t1 = tmp.shuffle(Shuffle::CBCB);
269         // t1 = (S14 S13 S14 S13)
270 
271         // All coefficients of t0, t1 are bounded with b < 1.6.
272         // Return (S12*S14 S15*S13 S15*S14 S12*S13) = (X3 Y3 Z3 T3)
273         ExtendedPoint(&t0 * &t1)
274     }
275 }
276 
277 impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint {
278     type Output = ExtendedPoint;
279 
280     /// Implement subtraction by negating the point and adding.
281     ///
282     /// Empirically, this seems about the same cost as a custom
283     /// subtraction impl (maybe because the benefit is cancelled by
284     /// increased code size?)
sub(self, other: &'b CachedPoint) -> ExtendedPoint285     fn sub(self, other: &'b CachedPoint) -> ExtendedPoint {
286         self + &(-other)
287     }
288 }
289 
290 impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable<CachedPoint> {
from(point: &'a edwards::EdwardsPoint) -> Self291     fn from(point: &'a edwards::EdwardsPoint) -> Self {
292         let P = ExtendedPoint::from(*point);
293         let mut points = [CachedPoint::from(P); 8];
294         for i in 0..7 {
295             points[i + 1] = (&P + &points[i]).into();
296         }
297         LookupTable(points)
298     }
299 }
300 
301 impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5<CachedPoint> {
from(point: &'a edwards::EdwardsPoint) -> Self302     fn from(point: &'a edwards::EdwardsPoint) -> Self {
303         let A = ExtendedPoint::from(*point);
304         let mut Ai = [CachedPoint::from(A); 8];
305         let A2 = A.double();
306         for i in 0..7 {
307             Ai[i + 1] = (&A2 + &Ai[i]).into();
308         }
309         // Now Ai = [A, 3A, 5A, 7A, 9A, 11A, 13A, 15A]
310         NafLookupTable5(Ai)
311     }
312 }
313 
314 impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8<CachedPoint> {
from(point: &'a edwards::EdwardsPoint) -> Self315     fn from(point: &'a edwards::EdwardsPoint) -> Self {
316         let A = ExtendedPoint::from(*point);
317         let mut Ai = [CachedPoint::from(A); 64];
318         let A2 = A.double();
319         for i in 0..63 {
320             Ai[i + 1] = (&A2 + &Ai[i]).into();
321         }
322         // Now Ai = [A, 3A, 5A, 7A, 9A, 11A, 13A, 15A, ..., 127A]
323         NafLookupTable8(Ai)
324     }
325 }
326 
#[cfg(test)]
mod test {
    use super::*;

    /// Serial transcription of the parallel readdition formulas used by
    /// `&ExtendedPoint + &CachedPoint`.
    ///
    /// Intermediate values are printed for debugging. The `R*` comments
    /// give the matching register names.
    fn serial_add(P: edwards::EdwardsPoint, Q: edwards::EdwardsPoint) -> edwards::EdwardsPoint {
        use backend::serial::u64::field::FieldElement51;

        let (X1, Y1, Z1, T1) = (P.X, P.Y, P.Z, P.T);
        let (X2, Y2, Z2, T2) = (Q.X, Q.Y, Q.Z, Q.T);

        macro_rules! print_var {
            ($x:ident) => {
                println!("{} = {:?}", stringify!($x), $x.to_bytes());
            };
        }

        let S0 = &Y1 - &X1; // R1
        let S1 = &Y1 + &X1; // R3
        let S2 = &Y2 - &X2; // R2
        let S3 = &Y2 + &X2; // R4
        print_var!(S0);
        print_var!(S1);
        print_var!(S2);
        print_var!(S3);
        println!();

        let S4 = &S0 * &S2; // R5 = R1 * R2
        let S5 = &S1 * &S3; // R6 = R3 * R4
        let S6 = &Z1 * &Z2; // R8
        let S7 = &T1 * &T2; // R7
        print_var!(S4);
        print_var!(S5);
        print_var!(S6);
        print_var!(S7);
        println!();

        let S8  =  &S4 *    &FieldElement51([  121666,0,0,0,0]);  // R5
        let S9  =  &S5 *    &FieldElement51([  121666,0,0,0,0]);  // R6
        let S10 =  &S6 *    &FieldElement51([2*121666,0,0,0,0]);  // R8
        let S11 =  &S7 * &(-&FieldElement51([2*121665,0,0,0,0])); // R7
        print_var!(S8);
        print_var!(S9);
        print_var!(S10);
        print_var!(S11);
        println!();

        let S12 =  &S9 - &S8;  // R1
        let S13 =  &S9 + &S8;  // R4
        let S14 = &S10 - &S11; // R2
        let S15 = &S10 + &S11; // R3
        print_var!(S12);
        print_var!(S13);
        print_var!(S14);
        print_var!(S15);
        println!();

        let X3 = &S12 * &S14; // R1 * R2
        let Y3 = &S15 * &S13; // R3 * R4
        let Z3 = &S15 * &S14; // R2 * R3
        let T3 = &S12 * &S13; // R1 * R4

        edwards::EdwardsPoint {
            X: X3,
            Y: Y3,
            Z: Z3,
            T: T3,
        }
    }

    /// Compute `P + Q` and `P - Q` three ways and check that they agree:
    /// the serial transcription of the readdition formulas, the vector
    /// implementation, and the reference `EdwardsPoint` arithmetic.
    fn addition_test_helper(P: edwards::EdwardsPoint, Q: edwards::EdwardsPoint) {
        // Test the serial implementation of the parallel addition formulas
        let R_serial: edwards::EdwardsPoint = serial_add(P, Q);

        // Test the vector implementation of the parallel readdition formulas
        let cached_Q = CachedPoint::from(ExtendedPoint::from(Q));
        let R_vector: edwards::EdwardsPoint = (&ExtendedPoint::from(P) + &cached_Q).into();
        let S_vector: edwards::EdwardsPoint = (&ExtendedPoint::from(P) - &cached_Q).into();

        println!("Testing point addition:");
        println!("P = {:?}", P);
        println!("Q = {:?}", Q);
        println!("cached Q = {:?}", cached_Q);
        println!("R = P + Q = {:?}", &P + &Q);
        println!("R_serial = {:?}", R_serial);
        println!("R_vector = {:?}", R_vector);
        println!("S = P - Q = {:?}", &P - &Q);
        println!("S_vector = {:?}", S_vector);
        assert_eq!(R_serial.compress(), (&P + &Q).compress());
        assert_eq!(R_vector.compress(), (&P + &Q).compress());
        assert_eq!(S_vector.compress(), (&P - &Q).compress());
        println!("OK!\n");
    }

    #[test]
    fn vector_addition_vs_serial_addition_vs_edwards_extendedpoint() {
        use constants;
        use scalar::Scalar;

        println!("Testing id +- id");
        let P = edwards::EdwardsPoint::identity();
        let Q = edwards::EdwardsPoint::identity();
        addition_test_helper(P, Q);

        println!("Testing id +- B");
        let P = edwards::EdwardsPoint::identity();
        let Q = constants::ED25519_BASEPOINT_POINT;
        addition_test_helper(P, Q);

        println!("Testing B +- B");
        let P = constants::ED25519_BASEPOINT_POINT;
        let Q = constants::ED25519_BASEPOINT_POINT;
        addition_test_helper(P, Q);

        println!("Testing B +- kB");
        let P = constants::ED25519_BASEPOINT_POINT;
        let Q = &constants::ED25519_BASEPOINT_TABLE * &Scalar::from(8475983829u64);
        addition_test_helper(P, Q);
    }

    /// Serial transcription of the parallel doubling formulas used by
    /// `ExtendedPoint::double`.
    ///
    /// Intermediate values are printed for debugging.
    fn serial_double(P: edwards::EdwardsPoint) -> edwards::EdwardsPoint {
        let (X1, Y1, Z1, _T1) = (P.X, P.Y, P.Z, P.T);

        macro_rules! print_var {
            ($x:ident) => {
                println!("{} = {:?}", stringify!($x), $x.to_bytes());
            };
        }

        let S0 = &X1 + &Y1; // R1
        print_var!(S0);
        println!();

        let S1 = X1.square();
        let S2 = Y1.square();
        let S3 = Z1.square();
        let S4 = S0.square();
        print_var!(S1);
        print_var!(S2);
        print_var!(S3);
        print_var!(S4);
        println!();

        let S5 = &S1 + &S2;
        let S6 = &S1 - &S2;
        let S7 = &S3 + &S3;
        let S8 = &S7 + &S6;
        let S9 = &S5 - &S4;
        print_var!(S5);
        print_var!(S6);
        print_var!(S7);
        print_var!(S8);
        print_var!(S9);
        println!();

        let X3 = &S8 * &S9;
        let Y3 = &S5 * &S6;
        let Z3 = &S8 * &S6;
        let T3 = &S5 * &S9;

        edwards::EdwardsPoint {
            X: X3,
            Y: Y3,
            Z: Z3,
            T: T3,
        }
    }

    /// Compute `[2]P` serially and with the vector backend, and check both
    /// against `P + P` from the reference arithmetic.
    fn doubling_test_helper(P: edwards::EdwardsPoint) {
        let R1: edwards::EdwardsPoint = serial_double(P);
        let R2: edwards::EdwardsPoint = ExtendedPoint::from(P).double().into();
        println!("Testing point doubling:");
        println!("P = {:?}", P);
        println!("(serial) R1 = {:?}", R1);
        println!("(vector) R2 = {:?}", R2);
        println!("P + P = {:?}", &P + &P);
        assert_eq!(R1.compress(), (&P + &P).compress());
        assert_eq!(R2.compress(), (&P + &P).compress());
        println!("OK!\n");
    }

    #[test]
    fn vector_doubling_vs_serial_doubling_vs_edwards_extendedpoint() {
        use constants;
        use scalar::Scalar;

        println!("Testing [2]id");
        let P = edwards::EdwardsPoint::identity();
        doubling_test_helper(P);

        println!("Testing [2]B");
        let P = constants::ED25519_BASEPOINT_POINT;
        doubling_test_helper(P);

        println!("Testing [2]([k]B)");
        let P = &constants::ED25519_BASEPOINT_TABLE * &Scalar::from(8475983829u64);
        doubling_test_helper(P);
    }

    /// Check that the hard-coded `BASEPOINT_ODD_LOOKUP_TABLE` matches the
    /// table computed at runtime from the basepoint, lane by lane.
    #[test]
    fn basepoint_odd_lookup_table_verify() {
        use constants;
        use backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE;

        let basepoint_odd_table = NafLookupTable8::<CachedPoint>::from(&constants::ED25519_BASEPOINT_POINT);
        println!("basepoint_odd_lookup_table = {:?}", basepoint_odd_table);

        let table_B = &BASEPOINT_ODD_LOOKUP_TABLE;
        for (b_vec, base_vec) in table_B.0.iter().zip(basepoint_odd_table.0.iter()) {
            let b_splits = b_vec.0.split();
            let base_splits = base_vec.0.split();

            assert_eq!(base_splits[0], b_splits[0]);
            assert_eq!(base_splits[1], b_splits[1]);
            assert_eq!(base_splits[2], b_splits[2]);
            assert_eq!(base_splits[3], b_splits[3]);
        }
    }
}
545