1 // -*- mode: rust; -*-
2 //
3 // This file is part of curve25519-dalek.
4 // Copyright (c) 2016-2021 isis lovecruft
5 // Copyright (c) 2016-2019 Henry de Valence
6 // See LICENSE for licensing information.
7 //
8 // Authors:
9 // - isis agora lovecruft <isis@patternsinthevoid.net>
10 // - Henry de Valence <hdevalence@hdevalence.ca>
11 
12 //! Parallel Edwards Arithmetic for Curve25519.
13 //!
14 //! This module currently has two point types:
15 //!
16 //! * `ExtendedPoint`: a point stored in vector-friendly format, with
17 //! vectorized doubling and addition;
18 //!
19 //! * `CachedPoint`: used for readdition.
20 //!
21 //! Details on the formulas can be found in the documentation for the
22 //! parent `avx2` module.
23 //!
24 //! This API is designed to be safe: vectorized points can only be
25 //! created from serial points (which do validation on decompression),
26 //! and operations on valid points return valid points, so invalid
27 //! point states should be unrepresentable.
28 //!
29 //! This design goal is met, with one exception: the `Neg`
30 //! implementation for the `CachedPoint` performs a lazy negation, so
31 //! that subtraction can be efficiently implemented as a negation and
32 //! an addition.  Repeatedly negating a `CachedPoint` will cause its
//! coefficients to grow and eventually overflow.  Repeatedly negating
//! a point should not be necessary anyway.
35 
36 #![allow(non_snake_case)]
37 
38 use core::convert::From;
39 use core::ops::{Add, Neg, Sub};
40 
41 use subtle::Choice;
42 use subtle::ConditionallySelectable;
43 
44 use edwards;
45 use window::{LookupTable, NafLookupTable5, NafLookupTable8};
46 
47 use traits::Identity;
48 
49 use super::constants;
50 use super::field::{FieldElement2625x4, Lanes, Shuffle};
51 
52 /// A point on Curve25519, using parallel Edwards formulas for curve
53 /// operations.
54 ///
55 /// # Invariant
56 ///
57 /// The coefficients of an `ExtendedPoint` are bounded with
58 /// \\( b < 0.007 \\).
59 #[derive(Copy, Clone, Debug)]
60 pub struct ExtendedPoint(pub(super) FieldElement2625x4);
61 
62 impl From<edwards::EdwardsPoint> for ExtendedPoint {
from(P: edwards::EdwardsPoint) -> ExtendedPoint63     fn from(P: edwards::EdwardsPoint) -> ExtendedPoint {
64         ExtendedPoint(FieldElement2625x4::new(&P.X, &P.Y, &P.Z, &P.T))
65     }
66 }
67 
68 impl From<ExtendedPoint> for edwards::EdwardsPoint {
from(P: ExtendedPoint) -> edwards::EdwardsPoint69     fn from(P: ExtendedPoint) -> edwards::EdwardsPoint {
70         let tmp = P.0.split();
71         edwards::EdwardsPoint {
72             X: tmp[0],
73             Y: tmp[1],
74             Z: tmp[2],
75             T: tmp[3],
76         }
77     }
78 }
79 
80 impl ConditionallySelectable for ExtendedPoint {
conditional_select(a: &Self, b: &Self, choice: Choice) -> Self81     fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
82         ExtendedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice))
83     }
84 
conditional_assign(&mut self, other: &Self, choice: Choice)85     fn conditional_assign(&mut self, other: &Self, choice: Choice) {
86         self.0.conditional_assign(&other.0, choice);
87     }
88 }
89 
90 impl Default for ExtendedPoint {
default() -> ExtendedPoint91     fn default() -> ExtendedPoint {
92         ExtendedPoint::identity()
93     }
94 }
95 
96 impl Identity for ExtendedPoint {
identity() -> ExtendedPoint97     fn identity() -> ExtendedPoint {
98         constants::EXTENDEDPOINT_IDENTITY
99     }
100 }
101 
102 impl ExtendedPoint {
103     /// Compute the double of this point.
double(&self) -> ExtendedPoint104     pub fn double(&self) -> ExtendedPoint {
105         // Want to compute (X1 Y1 Z1 X1+Y1).
106         // Not sure how to do this less expensively than computing
107         // (X1 Y1 Z1 T1) --(256bit shuffle)--> (X1 Y1 X1 Y1)
108         // (X1 Y1 X1 Y1) --(2x128b shuffle)--> (Y1 X1 Y1 X1)
109         // and then adding.
110 
111         // Set tmp0 = (X1 Y1 X1 Y1)
112         let mut tmp0 = self.0.shuffle(Shuffle::ABAB);
113 
114         // Set tmp1 = (Y1 X1 Y1 X1)
115         let mut tmp1 = tmp0.shuffle(Shuffle::BADC);
116 
117         // Set tmp0 = (X1 Y1 Z1 X1+Y1)
118         tmp0 = self.0.blend(tmp0 + tmp1, Lanes::D);
119 
120         // Set tmp1 = tmp0^2, negating the D values
121         tmp1 = tmp0.square_and_negate_D();
122         // Now tmp1 = (S1 S2 S3 -S4) with b < 0.007
123 
124         // See discussion of bounds in the module-level documentation.
125         // We want to compute
126         //
127         //    + | S1 | S1 | S1 | S1 |
128         //    + | S2 |    |    | S2 |
129         //    + |    |    | S3 |    |
130         //    + |    |    | S3 |    |
131         //    + |    |    |    |-S4 |
132         //    + |    | 2p | 2p |    |
133         //    - |    | S2 | S2 |    |
134         //    =======================
135         //        S5   S6   S8   S9
136 
137         let zero = FieldElement2625x4::zero();
138         let S_1 = tmp1.shuffle(Shuffle::AAAA);
139         let S_2 = tmp1.shuffle(Shuffle::BBBB);
140 
141         tmp0 = zero.blend(tmp1 + tmp1, Lanes::C);
142         // tmp0 = (0, 0,  2S_3, 0)
143         tmp0 = tmp0.blend(tmp1, Lanes::D);
144         // tmp0 = (0, 0,  2S_3, -S_4)
145         tmp0 = tmp0 + S_1;
146         // tmp0 = (  S_1,   S_1, S_1 + 2S_3, S_1 - S_4)
147         tmp0 = tmp0 + zero.blend(S_2, Lanes::AD);
148         // tmp0 = (S_1 + S_2,   S_1, S_1 + 2S_3, S_1 + S_2 - S_4)
149         tmp0 = tmp0 + zero.blend(S_2.negate_lazy(), Lanes::BC);
150         // tmp0 = (S_1 + S_2, S_1 - S_2, S_1 - S_2 + 2S_3, S_1 + S_2 - S_4)
151         //    b < (     1.01,       1.6,             2.33,             1.6)
152         // Now tmp0 = (S_5, S_6, S_8, S_9)
153 
154         // Set tmp1 = ( S_9,  S_6,  S_6,  S_9)
155         //        b < ( 1.6,  1.6,  1.6,  1.6)
156         tmp1 = tmp0.shuffle(Shuffle::DBBD);
157         // Set tmp0 = ( S_8,  S_5,  S_8,  S_5)
158         //        b < (2.33, 1.01, 2.33, 1.01)
159         tmp0 = tmp0.shuffle(Shuffle::CACA);
160 
161         // Bounds on (tmp0, tmp1) are (2.33, 1.6) < (2.5, 1.75).
162         ExtendedPoint(&tmp0 * &tmp1)
163     }
164 
mul_by_pow_2(&self, k: u32) -> ExtendedPoint165     pub fn mul_by_pow_2(&self, k: u32) -> ExtendedPoint {
166         let mut tmp: ExtendedPoint = *self;
167         for _ in 0..k {
168             tmp = tmp.double();
169         }
170         tmp
171     }
172 }
173 
174 /// A cached point with some precomputed variables used for readdition.
175 ///
176 /// # Warning
177 ///
178 /// It is not safe to negate this point more than once.
179 ///
180 /// # Invariant
181 ///
182 /// As long as the `CachedPoint` is not repeatedly negated, its
183 /// coefficients will be bounded with \\( b < 1.0 \\).
184 #[derive(Copy, Clone, Debug)]
185 pub struct CachedPoint(pub(super) FieldElement2625x4);
186 
187 impl From<ExtendedPoint> for CachedPoint {
from(P: ExtendedPoint) -> CachedPoint188     fn from(P: ExtendedPoint) -> CachedPoint {
189         let mut x = P.0;
190 
191         x = x.blend(x.diff_sum(), Lanes::AB);
192         // x = (Y2 - X2, Y2 + X2, Z2, T2) = (S2 S3 Z2 T2)
193 
194         x = x * (121666, 121666, 2 * 121666, 2 * 121665);
195         // x = (121666*S2 121666*S3 2*121666*Z2 2*121665*T2)
196 
197         x = x.blend(-x, Lanes::D);
198         // x = (121666*S2 121666*S3 2*121666*Z2 -2*121665*T2)
199 
200         // The coefficients of the output are bounded with b < 0.007.
201         CachedPoint(x)
202     }
203 }
204 
205 impl Default for CachedPoint {
default() -> CachedPoint206     fn default() -> CachedPoint {
207         CachedPoint::identity()
208     }
209 }
210 
211 impl Identity for CachedPoint {
identity() -> CachedPoint212     fn identity() -> CachedPoint {
213         constants::CACHEDPOINT_IDENTITY
214     }
215 }
216 
217 impl ConditionallySelectable for CachedPoint {
conditional_select(a: &Self, b: &Self, choice: Choice) -> Self218     fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
219         CachedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice))
220     }
221 
conditional_assign(&mut self, other: &Self, choice: Choice)222     fn conditional_assign(&mut self, other: &Self, choice: Choice) {
223         self.0.conditional_assign(&other.0, choice);
224     }
225 }
226 
227 impl<'a> Neg for &'a CachedPoint {
228     type Output = CachedPoint;
229     /// Lazily negate the point.
230     ///
231     /// # Warning
232     ///
233     /// Because this method does not perform a reduction, it is not
234     /// safe to repeatedly negate a point.
neg(self) -> CachedPoint235     fn neg(self) -> CachedPoint {
236         let swapped = self.0.shuffle(Shuffle::BACD);
237         CachedPoint(swapped.blend(swapped.negate_lazy(), Lanes::D))
238     }
239 }
240 
241 impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint {
242     type Output = ExtendedPoint;
243 
244     /// Add an `ExtendedPoint` and a `CachedPoint`.
add(self, other: &'b CachedPoint) -> ExtendedPoint245     fn add(self, other: &'b CachedPoint) -> ExtendedPoint {
246         // The coefficients of an `ExtendedPoint` are reduced after
247         // every operation.  If the `CachedPoint` was negated, its
248         // coefficients grow by one bit.  So on input, `self` is
249         // bounded with `b < 0.007` and `other` is bounded with
250         // `b < 1.0`.
251 
252         let mut tmp = self.0;
253 
254         tmp = tmp.blend(tmp.diff_sum(), Lanes::AB);
255         // tmp = (Y1-X1 Y1+X1 Z1 T1) = (S0 S1 Z1 T1) with b < 1.6
256 
257         // (tmp, other) bounded with b < (1.6, 1.0) < (2.5, 1.75).
258         tmp = &tmp * &other.0;
259         // tmp = (S0*S2' S1*S3' Z1*Z2' T1*T2') = (S8 S9 S10 S11)
260 
261         tmp = tmp.shuffle(Shuffle::ABDC);
262         // tmp = (S8 S9 S11 S10)
263 
264         tmp = tmp.diff_sum();
265         // tmp = (S9-S8 S9+S8 S10-S11 S10+S11) = (S12 S13 S14 S15)
266 
267         let t0 = tmp.shuffle(Shuffle::ADDA);
268         // t0 = (S12 S15 S15 S12)
269         let t1 = tmp.shuffle(Shuffle::CBCB);
270         // t1 = (S14 S13 S14 S13)
271 
272         // All coefficients of t0, t1 are bounded with b < 1.6.
273         // Return (S12*S14 S15*S13 S15*S14 S12*S13) = (X3 Y3 Z3 T3)
274         ExtendedPoint(&t0 * &t1)
275     }
276 }
277 
278 impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint {
279     type Output = ExtendedPoint;
280 
281     /// Implement subtraction by negating the point and adding.
282     ///
283     /// Empirically, this seems about the same cost as a custom
284     /// subtraction impl (maybe because the benefit is cancelled by
285     /// increased code size?)
sub(self, other: &'b CachedPoint) -> ExtendedPoint286     fn sub(self, other: &'b CachedPoint) -> ExtendedPoint {
287         self + &(-other)
288     }
289 }
290 
291 impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable<CachedPoint> {
from(point: &'a edwards::EdwardsPoint) -> Self292     fn from(point: &'a edwards::EdwardsPoint) -> Self {
293         let P = ExtendedPoint::from(*point);
294         let mut points = [CachedPoint::from(P); 8];
295         for i in 0..7 {
296             points[i + 1] = (&P + &points[i]).into();
297         }
298         LookupTable(points)
299     }
300 }
301 
302 impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5<CachedPoint> {
from(point: &'a edwards::EdwardsPoint) -> Self303     fn from(point: &'a edwards::EdwardsPoint) -> Self {
304         let A = ExtendedPoint::from(*point);
305         let mut Ai = [CachedPoint::from(A); 8];
306         let A2 = A.double();
307         for i in 0..7 {
308             Ai[i + 1] = (&A2 + &Ai[i]).into();
309         }
310         // Now Ai = [A, 3A, 5A, 7A, 9A, 11A, 13A, 15A]
311         NafLookupTable5(Ai)
312     }
313 }
314 
315 impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8<CachedPoint> {
from(point: &'a edwards::EdwardsPoint) -> Self316     fn from(point: &'a edwards::EdwardsPoint) -> Self {
317         let A = ExtendedPoint::from(*point);
318         let mut Ai = [CachedPoint::from(A); 64];
319         let A2 = A.double();
320         for i in 0..63 {
321             Ai[i + 1] = (&A2 + &Ai[i]).into();
322         }
323         // Now Ai = [A, 3A, 5A, 7A, 9A, 11A, 13A, 15A, ..., 127A]
324         NafLookupTable8(Ai)
325     }
326 }
327 
#[cfg(test)]
mod test {
    use super::*;

    // Print a variable as `name = bytes`, where the bytes are whatever
    // its `to_bytes()` method returns.  Shared by the serial reference
    // implementations below so that intermediate values can be
    // compared by eye when a test fails.
    macro_rules! print_var {
        ($x:ident) => {
            println!("{} = {:?}", stringify!($x), $x.to_bytes());
        };
    }

    /// Serial implementation of the parallel addition formulas, one
    /// field operation at a time, printing every intermediate value.
    ///
    /// The `// R…` comments record which register each value occupies.
    fn serial_add(P: edwards::EdwardsPoint, Q: edwards::EdwardsPoint) -> edwards::EdwardsPoint {
        use backend::serial::u64::field::FieldElement51;

        let (X1, Y1, Z1, T1) = (P.X, P.Y, P.Z, P.T);
        let (X2, Y2, Z2, T2) = (Q.X, Q.Y, Q.Z, Q.T);

        let S0 = &Y1 - &X1; // R1
        let S1 = &Y1 + &X1; // R3
        let S2 = &Y2 - &X2; // R2
        let S3 = &Y2 + &X2; // R4
        print_var!(S0);
        print_var!(S1);
        print_var!(S2);
        print_var!(S3);
        println!();

        let S4 = &S0 * &S2; // R5 = R1 * R2
        let S5 = &S1 * &S3; // R6 = R3 * R4
        let S6 = &Z1 * &Z2; // R8
        let S7 = &T1 * &T2; // R7
        print_var!(S4);
        print_var!(S5);
        print_var!(S6);
        print_var!(S7);
        println!();

        let S8 = &S4 * &FieldElement51([121666, 0, 0, 0, 0]); // R5
        let S9 = &S5 * &FieldElement51([121666, 0, 0, 0, 0]); // R6
        let S10 = &S6 * &FieldElement51([2 * 121666, 0, 0, 0, 0]); // R8
        let S11 = &S7 * &(-&FieldElement51([2 * 121665, 0, 0, 0, 0])); // R7
        print_var!(S8);
        print_var!(S9);
        print_var!(S10);
        print_var!(S11);
        println!();

        let S12 = &S9 - &S8; // R1
        let S13 = &S9 + &S8; // R4
        let S14 = &S10 - &S11; // R2
        let S15 = &S10 + &S11; // R3
        print_var!(S12);
        print_var!(S13);
        print_var!(S14);
        print_var!(S15);
        println!();

        let X3 = &S12 * &S14; // R1 * R2
        let Y3 = &S15 * &S13; // R3 * R4
        let Z3 = &S15 * &S14; // R2 * R3
        let T3 = &S12 * &S13; // R1 * R4

        edwards::EdwardsPoint {
            X: X3,
            Y: Y3,
            Z: Z3,
            T: T3,
        }
    }

    /// Check `P + Q` (serial formulas and vector backend) and `P - Q`
    /// (vector backend) against the `EdwardsPoint` operators, comparing
    /// compressed encodings.
    fn addition_test_helper(P: edwards::EdwardsPoint, Q: edwards::EdwardsPoint) {
        // Test the serial implementation of the parallel addition formulas
        let R_serial = serial_add(P, Q);

        // Test the vector implementation of the parallel readdition formulas
        let cached_Q = CachedPoint::from(ExtendedPoint::from(Q));
        let R_vector: edwards::EdwardsPoint = (&ExtendedPoint::from(P) + &cached_Q).into();
        let S_vector: edwards::EdwardsPoint = (&ExtendedPoint::from(P) - &cached_Q).into();

        // Reference results, computed once and reused below.
        let R_expected = &P + &Q;
        let S_expected = &P - &Q;

        println!("Testing point addition:");
        println!("P = {:?}", P);
        println!("Q = {:?}", Q);
        println!("cached Q = {:?}", cached_Q);
        println!("R = P + Q = {:?}", R_expected);
        println!("R_serial = {:?}", R_serial);
        println!("R_vector = {:?}", R_vector);
        println!("S = P - Q = {:?}", S_expected);
        println!("S_vector = {:?}", S_vector);
        assert_eq!(R_serial.compress(), R_expected.compress());
        assert_eq!(R_vector.compress(), R_expected.compress());
        assert_eq!(S_vector.compress(), S_expected.compress());
        println!("OK!\n");
    }

    #[test]
    fn vector_addition_vs_serial_addition_vs_edwards_extendedpoint() {
        use constants;
        use scalar::Scalar;

        let id = edwards::EdwardsPoint::identity();
        let B = constants::ED25519_BASEPOINT_POINT;
        let kB = &constants::ED25519_BASEPOINT_TABLE * &Scalar::from(8475983829u64);

        let cases = [
            ("Testing id +- id", id, id),
            ("Testing id +- B", id, B),
            ("Testing B +- B", B, B),
            ("Testing B +- kB", B, kB),
        ];
        for &(label, P, Q) in cases.iter() {
            println!("{}", label);
            addition_test_helper(P, Q);
        }
    }

    /// Serial implementation of the parallel doubling formulas, one
    /// field operation at a time, printing every intermediate value.
    fn serial_double(P: edwards::EdwardsPoint) -> edwards::EdwardsPoint {
        let (X1, Y1, Z1, _T1) = (P.X, P.Y, P.Z, P.T);

        let S0 = &X1 + &Y1; // R1
        print_var!(S0);
        println!();

        let S1 = X1.square();
        let S2 = Y1.square();
        let S3 = Z1.square();
        let S4 = S0.square();
        print_var!(S1);
        print_var!(S2);
        print_var!(S3);
        print_var!(S4);
        println!();

        let S5 = &S1 + &S2;
        let S6 = &S1 - &S2;
        let S7 = &S3 + &S3;
        let S8 = &S7 + &S6;
        let S9 = &S5 - &S4;
        print_var!(S5);
        print_var!(S6);
        print_var!(S7);
        print_var!(S8);
        print_var!(S9);
        println!();

        let X3 = &S8 * &S9;
        let Y3 = &S5 * &S6;
        let Z3 = &S8 * &S6;
        let T3 = &S5 * &S9;

        edwards::EdwardsPoint {
            X: X3,
            Y: Y3,
            Z: Z3,
            T: T3,
        }
    }

    /// Check the serial and vector doublings of `P` against `P + P`
    /// computed with the `EdwardsPoint` operator, comparing compressed
    /// encodings.
    fn doubling_test_helper(P: edwards::EdwardsPoint) {
        let R1 = serial_double(P);
        let R2: edwards::EdwardsPoint = ExtendedPoint::from(P).double().into();

        // Reference result, computed once and reused below.
        let expected = &P + &P;

        println!("Testing point doubling:");
        println!("P = {:?}", P);
        println!("(serial) R1 = {:?}", R1);
        println!("(vector) R2 = {:?}", R2);
        println!("P + P = {:?}", expected);
        assert_eq!(R1.compress(), expected.compress());
        assert_eq!(R2.compress(), expected.compress());
        println!("OK!\n");
    }

    #[test]
    fn vector_doubling_vs_serial_doubling_vs_edwards_extendedpoint() {
        use constants;
        use scalar::Scalar;

        let cases = [
            ("Testing [2]id", edwards::EdwardsPoint::identity()),
            ("Testing [2]B", constants::ED25519_BASEPOINT_POINT),
            (
                "Testing [2]([k]B)",
                &constants::ED25519_BASEPOINT_TABLE * &Scalar::from(8475983829u64),
            ),
        ];
        for &(label, P) in cases.iter() {
            println!("{}", label);
            doubling_test_helper(P);
        }
    }

    #[test]
    fn basepoint_odd_lookup_table_verify() {
        use backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE;
        use constants;

        // Recompute the table from the basepoint and compare it, lane
        // by lane, with the hard-coded constant.
        let basepoint_odd_table =
            NafLookupTable8::<CachedPoint>::from(&constants::ED25519_BASEPOINT_POINT);
        println!("basepoint_odd_lookup_table = {:?}", basepoint_odd_table);

        let table_B = &BASEPOINT_ODD_LOOKUP_TABLE;
        let pairs = table_B.0.iter().zip(basepoint_odd_table.0.iter());
        for (entry, (b_vec, base_vec)) in pairs.enumerate() {
            let b_splits = b_vec.0.split();
            let base_splits = base_vec.0.split();

            for lane in 0..4 {
                assert_eq!(
                    base_splits[lane], b_splits[lane],
                    "table entry {} differs in lane {}",
                    entry, lane
                );
            }
        }
    }
}
546