1 // -*- mode: rust; coding: utf-8; -*-
2 //
3 // This file is part of curve25519-dalek.
4 // Copyright (c) 2018-2019 Henry de Valence
5 // See LICENSE for licensing information.
6 //
7 // Authors:
8 // - Henry de Valence <hdevalence@hdevalence.ca>
9 
10 #![allow(non_snake_case)]
11 
12 use core::ops::{Add, Mul, Neg};
13 use packed_simd::{u64x4, IntoBits};
14 
15 use backend::serial::u64::field::FieldElement51;
16 
17 /// A wrapper around `vpmadd52luq` that works on `u64x4`.
18 #[inline(always)]
madd52lo(z: u64x4, x: u64x4, y: u64x4) -> u64x419 unsafe fn madd52lo(z: u64x4, x: u64x4, y: u64x4) -> u64x4 {
20     use core::arch::x86_64::_mm256_madd52lo_epu64;
21     _mm256_madd52lo_epu64(z.into_bits(), x.into_bits(), y.into_bits()).into_bits()
22 }
23 
24 /// A wrapper around `vpmadd52huq` that works on `u64x4`.
25 #[inline(always)]
madd52hi(z: u64x4, x: u64x4, y: u64x4) -> u64x426 unsafe fn madd52hi(z: u64x4, x: u64x4, y: u64x4) -> u64x4 {
27     use core::arch::x86_64::_mm256_madd52hi_epu64;
28     _mm256_madd52hi_epu64(z.into_bits(), x.into_bits(), y.into_bits()).into_bits()
29 }
30 
/// A vector of four field elements in radix 2^51, with unreduced coefficients.
///
/// Storage is transposed: the `u64x4` at index `j` holds limb `j` of all
/// four elements, so lane `i` of every vector belongs to element `i`
/// (see `F51x4Unreduced::new`/`split`).
#[derive(Copy, Clone, Debug)]
pub struct F51x4Unreduced(pub(crate) [u64x4; 5]);
34 
/// A vector of four field elements in radix 2^51, with reduced coefficients.
///
/// Same transposed layout as `F51x4Unreduced`; produced by the weak
/// reduction in `From<F51x4Unreduced>`, which bounds each limb below 2^52.
#[derive(Copy, Clone, Debug)]
pub struct F51x4Reduced(pub(crate) [u64x4; 5]);
38 
/// A shuffle of the four 64-bit lanes, which are named `(A, B, C, D)`.
///
/// A variant name spells the output in lane order: e.g. `BADC` produces
/// `(B, A, D, C)` from input `(A, B, C, D)` (see `shuffle_lanes` for the
/// corresponding permutation constants).
#[derive(Copy, Clone)]
pub enum Shuffle {
    AAAA,
    BBBB,
    BADC,
    BACD,
    ADDA,
    CBCB,
    ABDC,
    ABAB,
    DBBD,
    CACA,
}
52 
/// Permute the four 64-bit lanes of `x` according to `control`.
///
/// Uses `vpermq`; the immediate operand must be a literal constant, so
/// each `Shuffle` variant gets its own intrinsic call.  Control bytes
/// read in pairs from the low bits: e.g. `BADC = 0b10_11_00_01` selects
/// source lanes (1, 0, 3, 2) for output lanes (0, 1, 2, 3).
#[inline(always)]
fn shuffle_lanes(x: u64x4, control: Shuffle) -> u64x4 {
    unsafe {
        use core::arch::x86_64::_mm256_permute4x64_epi64 as perm;

        match control {
            Shuffle::AAAA => perm(x.into_bits(), 0b00_00_00_00).into_bits(),
            Shuffle::BBBB => perm(x.into_bits(), 0b01_01_01_01).into_bits(),
            Shuffle::BADC => perm(x.into_bits(), 0b10_11_00_01).into_bits(),
            Shuffle::BACD => perm(x.into_bits(), 0b11_10_00_01).into_bits(),
            Shuffle::ADDA => perm(x.into_bits(), 0b00_11_11_00).into_bits(),
            Shuffle::CBCB => perm(x.into_bits(), 0b01_10_01_10).into_bits(),
            Shuffle::ABDC => perm(x.into_bits(), 0b10_11_01_00).into_bits(),
            Shuffle::ABAB => perm(x.into_bits(), 0b01_00_01_00).into_bits(),
            Shuffle::DBBD => perm(x.into_bits(), 0b11_01_01_11).into_bits(),
            Shuffle::CACA => perm(x.into_bits(), 0b00_10_00_10).into_bits(),
        }
    }
}
72 
/// A subset of the four 64-bit lanes `(A, B, C, D)`.
///
/// Used by `blend_lanes`: the named lanes are taken from the second
/// operand, the remaining lanes from the first.
#[derive(Copy, Clone)]
pub enum Lanes {
    D,
    C,
    AB,
    AC,
    AD,
    BCD,
}
82 
/// Blend `x` and `y`: lanes named by `control` come from `y`, the rest
/// from `x`.
///
/// Uses `vpblendd` on 32-bit lanes (each 64-bit lane is a pair of mask
/// bits); the immediate operand must be a literal constant, so each
/// `Lanes` variant gets its own intrinsic call.  Mask bits read from the
/// low end: bits 1:0 = lane A, ..., bits 7:6 = lane D.
#[inline]
fn blend_lanes(x: u64x4, y: u64x4, control: Lanes) -> u64x4 {
    unsafe {
        use core::arch::x86_64::_mm256_blend_epi32 as blend;

        match control {
            Lanes::D => blend(x.into_bits(), y.into_bits(), 0b11_00_00_00).into_bits(),
            Lanes::C => blend(x.into_bits(), y.into_bits(), 0b00_11_00_00).into_bits(),
            Lanes::AB => blend(x.into_bits(), y.into_bits(), 0b00_00_11_11).into_bits(),
            Lanes::AC => blend(x.into_bits(), y.into_bits(), 0b00_11_00_11).into_bits(),
            Lanes::AD => blend(x.into_bits(), y.into_bits(), 0b11_00_00_11).into_bits(),
            Lanes::BCD => blend(x.into_bits(), y.into_bits(), 0b11_11_11_00).into_bits(),
        }
    }
}
98 
impl F51x4Unreduced {
    /// Return the vector of four zero field elements.
    pub fn zero() -> F51x4Unreduced {
        F51x4Unreduced([u64x4::splat(0); 5])
    }

    /// Pack four serial `FieldElement51`s into vectorized form.
    ///
    /// Storage is transposed: the `u64x4` at index `j` holds limb `j` of
    /// all four elements, with lane `i` belonging to input `x_i`.
    pub fn new(
        x0: &FieldElement51,
        x1: &FieldElement51,
        x2: &FieldElement51,
        x3: &FieldElement51,
    ) -> F51x4Unreduced {
        F51x4Unreduced([
            u64x4::new(x0.0[0], x1.0[0], x2.0[0], x3.0[0]),
            u64x4::new(x0.0[1], x1.0[1], x2.0[1], x3.0[1]),
            u64x4::new(x0.0[2], x1.0[2], x2.0[2], x3.0[2]),
            u64x4::new(x0.0[3], x1.0[3], x2.0[3], x3.0[3]),
            u64x4::new(x0.0[4], x1.0[4], x2.0[4], x3.0[4]),
        ])
    }

    /// Unpack into four serial `FieldElement51`s (the inverse of `new`).
    pub fn split(&self) -> [FieldElement51; 4] {
        let x = &self.0;
        [
            // Element i is lane i of each of the five limb vectors.
            FieldElement51([
                x[0].extract(0),
                x[1].extract(0),
                x[2].extract(0),
                x[3].extract(0),
                x[4].extract(0),
            ]),
            FieldElement51([
                x[0].extract(1),
                x[1].extract(1),
                x[2].extract(1),
                x[3].extract(1),
                x[4].extract(1),
            ]),
            FieldElement51([
                x[0].extract(2),
                x[1].extract(2),
                x[2].extract(2),
                x[3].extract(2),
                x[4].extract(2),
            ]),
            FieldElement51([
                x[0].extract(3),
                x[1].extract(3),
                x[2].extract(3),
                x[3].extract(3),
                x[4].extract(3),
            ]),
        ]
    }

    /// Given `(A, B, C, D)`, compute `(B - A, B + A, D - C, D + C)`.
    ///
    /// The subtractions are lazy: `negate_lazy` adds a multiple of `p`
    /// instead of reducing, so no underflow occurs.
    #[inline]
    pub fn diff_sum(&self) -> F51x4Unreduced {
        // tmp1 = (B, A, D, C)
        let tmp1 = self.shuffle(Shuffle::BADC);
        // tmp2 = (-A, B, -C, D)
        let tmp2 = self.blend(&self.negate_lazy(), Lanes::AC);
        // (B - A, B + A, D - C, D + C)
        tmp1 + tmp2
    }

    /// Negate lazily: compute `16*p - self` limb-wise, which is congruent
    /// to `-self (mod p)` without performing a reduction.
    ///
    /// The constants are the limbs of `16*p` in radix 2^51:
    /// `lo = 16*(2^51 - 19)` and `hi = 16*(2^51 - 1)`.  Underflow is
    /// avoided provided each input limb is below the corresponding
    /// constant (assumes suitably bounded inputs — not checked here).
    #[inline]
    pub fn negate_lazy(&self) -> F51x4Unreduced {
        let lo = u64x4::splat(36028797018963664u64);
        let hi = u64x4::splat(36028797018963952u64);
        F51x4Unreduced([
            lo - self.0[0],
            hi - self.0[1],
            hi - self.0[2],
            hi - self.0[3],
            hi - self.0[4],
        ])
    }

    /// Permute the four lanes of each limb vector according to `control`.
    #[inline]
    pub fn shuffle(&self, control: Shuffle) -> F51x4Unreduced {
        F51x4Unreduced([
            shuffle_lanes(self.0[0], control),
            shuffle_lanes(self.0[1], control),
            shuffle_lanes(self.0[2], control),
            shuffle_lanes(self.0[3], control),
            shuffle_lanes(self.0[4], control),
        ])
    }

    /// Combine lanes: lanes named by `control` are taken from `other`,
    /// the rest from `self`.
    #[inline]
    pub fn blend(&self, other: &F51x4Unreduced, control: Lanes) -> F51x4Unreduced {
        F51x4Unreduced([
            blend_lanes(self.0[0], other.0[0], control),
            blend_lanes(self.0[1], other.0[1], control),
            blend_lanes(self.0[2], other.0[2], control),
            blend_lanes(self.0[3], other.0[3], control),
            blend_lanes(self.0[4], other.0[4], control),
        ])
    }
}
198 
199 impl Neg for F51x4Reduced {
200     type Output = F51x4Reduced;
201 
neg(self) -> F51x4Reduced202     fn neg(self) -> F51x4Reduced {
203         F51x4Unreduced::from(self).negate_lazy().into()
204     }
205 }
206 
207 use subtle::Choice;
208 use subtle::ConditionallySelectable;
209 
210 impl ConditionallySelectable for F51x4Reduced {
211     #[inline]
conditional_select(a: &F51x4Reduced, b: &F51x4Reduced, choice: Choice) -> F51x4Reduced212     fn conditional_select(a: &F51x4Reduced, b: &F51x4Reduced, choice: Choice) -> F51x4Reduced {
213         let mask = (-(choice.unwrap_u8() as i64)) as u64;
214         let mask_vec = u64x4::splat(mask);
215         F51x4Reduced([
216             a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])),
217             a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])),
218             a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])),
219             a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])),
220             a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])),
221         ])
222     }
223 
224     #[inline]
conditional_assign(&mut self, other: &F51x4Reduced, choice: Choice)225     fn conditional_assign(&mut self, other: &F51x4Reduced, choice: Choice) {
226         let mask = (-(choice.unwrap_u8() as i64)) as u64;
227         let mask_vec = u64x4::splat(mask);
228         self.0[0] ^= mask_vec & (self.0[0] ^ other.0[0]);
229         self.0[1] ^= mask_vec & (self.0[1] ^ other.0[1]);
230         self.0[2] ^= mask_vec & (self.0[2] ^ other.0[2]);
231         self.0[3] ^= mask_vec & (self.0[3] ^ other.0[3]);
232         self.0[4] ^= mask_vec & (self.0[4] ^ other.0[4]);
233     }
234 }
235 
impl F51x4Reduced {
    /// Permute the four lanes of each limb vector according to `control`.
    #[inline]
    pub fn shuffle(&self, control: Shuffle) -> F51x4Reduced {
        F51x4Reduced([
            shuffle_lanes(self.0[0], control),
            shuffle_lanes(self.0[1], control),
            shuffle_lanes(self.0[2], control),
            shuffle_lanes(self.0[3], control),
            shuffle_lanes(self.0[4], control),
        ])
    }

    /// Combine lanes: lanes named by `control` are taken from `other`,
    /// the rest from `self`.
    #[inline]
    pub fn blend(&self, other: &F51x4Reduced, control: Lanes) -> F51x4Reduced {
        F51x4Reduced([
            blend_lanes(self.0[0], other.0[0], control),
            blend_lanes(self.0[1], other.0[1], control),
            blend_lanes(self.0[2], other.0[2], control),
            blend_lanes(self.0[3], other.0[3], control),
            blend_lanes(self.0[4], other.0[4], control),
        ])
    }

    /// Square each of the four field elements in parallel using IFMA.
    ///
    /// Schoolbook squaring of the 5-limb radix-2^51 representation into
    /// limbs z0..z9.  Each 52-bit `madd52lo`/`madd52hi` partial product
    /// carries a coefficient in the final sum, tracked in the accumulator
    /// name:
    ///
    /// * `z*_1`: coefficient 1 (squares' low halves),
    /// * `z*_2`: coefficient 2 (symmetric cross terms x_i*x_j, i != j, and
    ///   high halves of squares — a high half sits at 2^52 = 2*2^51),
    /// * `z*_4`: coefficient 4 (high halves of cross terms: 2 * 2).
    ///
    /// The upper limbs z5..z9 are then folded back into z0..z4 using
    /// 2^255 = 19 (mod p).
    #[inline]
    pub fn square(&self) -> F51x4Unreduced {
        unsafe {
            let x = &self.0;

            // Represent values with coeff. 2
            let mut z0_2 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);
            let mut z6_2 = u64x4::splat(0);
            let mut z7_2 = u64x4::splat(0);
            let mut z9_2 = u64x4::splat(0);

            // Represent values with coeff. 4
            let mut z2_4 = u64x4::splat(0);
            let mut z3_4 = u64x4::splat(0);
            let mut z4_4 = u64x4::splat(0);
            let mut z5_4 = u64x4::splat(0);
            let mut z6_4 = u64x4::splat(0);
            let mut z7_4 = u64x4::splat(0);
            let mut z8_4 = u64x4::splat(0);

            // z0 = x0^2 (low half).
            let mut z0_1 = u64x4::splat(0);
            z0_1 = madd52lo(z0_1, x[0], x[0]);

            // z1 = 2*x0*x1 + hi(x0^2); both terms carry coefficient 2.
            let mut z1_1 = u64x4::splat(0);
            z1_2 = madd52lo(z1_2, x[0], x[1]);
            z1_2 = madd52hi(z1_2, x[0], x[0]);

            // Coeff-4 terms are folded in immediately via a shift by 2.
            z2_4 = madd52hi(z2_4, x[0], x[1]);
            let mut z2_1 = z2_4 << 2;
            z2_2 = madd52lo(z2_2, x[0], x[2]);
            z2_1 = madd52lo(z2_1, x[1], x[1]);

            z3_4 = madd52hi(z3_4, x[0], x[2]);
            let mut z3_1 = z3_4 << 2;
            z3_2 = madd52lo(z3_2, x[1], x[2]);
            z3_2 = madd52lo(z3_2, x[0], x[3]);
            z3_2 = madd52hi(z3_2, x[1], x[1]);

            z4_4 = madd52hi(z4_4, x[1], x[2]);
            z4_4 = madd52hi(z4_4, x[0], x[3]);
            let mut z4_1 = z4_4 << 2;
            z4_2 = madd52lo(z4_2, x[1], x[3]);
            z4_2 = madd52lo(z4_2, x[0], x[4]);
            z4_1 = madd52lo(z4_1, x[2], x[2]);

            z5_4 = madd52hi(z5_4, x[1], x[3]);
            z5_4 = madd52hi(z5_4, x[0], x[4]);
            let mut z5_1 = z5_4 << 2;
            z5_2 = madd52lo(z5_2, x[2], x[3]);
            z5_2 = madd52lo(z5_2, x[1], x[4]);
            z5_2 = madd52hi(z5_2, x[2], x[2]);

            z6_4 = madd52hi(z6_4, x[2], x[3]);
            z6_4 = madd52hi(z6_4, x[1], x[4]);
            let mut z6_1 = z6_4 << 2;
            z6_2 = madd52lo(z6_2, x[2], x[4]);
            z6_1 = madd52lo(z6_1, x[3], x[3]);

            z7_4 = madd52hi(z7_4, x[2], x[4]);
            let mut z7_1 = z7_4 << 2;
            z7_2 = madd52lo(z7_2, x[3], x[4]);
            z7_2 = madd52hi(z7_2, x[3], x[3]);

            z8_4 = madd52hi(z8_4, x[3], x[4]);
            let mut z8_1 = z8_4 << 2;
            z8_1 = madd52lo(z8_1, x[4], x[4]);

            // z9 = hi(x4^2), coefficient 2.
            let mut z9_1 = u64x4::splat(0);
            z9_2 = madd52hi(z9_2, x[4], x[4]);

            // Collapse coeff-2 accumulators for the upper limbs before
            // they are folded by 19 below.
            z5_1 += z5_2 << 1;
            z6_1 += z6_2 << 1;
            z7_1 += z7_2 << 1;
            z9_1 += z9_2 << 1;

            let mut t0 = u64x4::splat(0);
            let mut t1 = u64x4::splat(0);
            let r19 = u64x4::splat(19);

            // Fold z5..z9 into z0..z4 using 2^255 = 19 (mod p); each upper
            // limb contributes its low 52 bits, its bits above 52, and its
            // madd52hi overflow to adjacent lower limbs, all scaled by 19.
            t0 = madd52hi(t0, r19, z9_1);
            t1 = madd52lo(t1, r19, z9_1 >> 52);

            z4_2 = madd52lo(z4_2, r19, z8_1 >> 52);
            z3_2 = madd52lo(z3_2, r19, z7_1 >> 52);
            z2_2 = madd52lo(z2_2, r19, z6_1 >> 52);
            z1_2 = madd52lo(z1_2, r19, z5_1 >> 52);

            z0_2 = madd52lo(z0_2, r19, t0 + t1);
            z1_2 = madd52hi(z1_2, r19, z5_1);
            z2_2 = madd52hi(z2_2, r19, z6_1);
            z3_2 = madd52hi(z3_2, r19, z7_1);
            z4_2 = madd52hi(z4_2, r19, z8_1);

            z0_1 = madd52lo(z0_1, r19, z5_1);
            z1_1 = madd52lo(z1_1, r19, z6_1);
            z2_1 = madd52lo(z2_1, r19, z7_1);
            z3_1 = madd52lo(z3_1, r19, z8_1);
            z4_1 = madd52lo(z4_1, r19, z9_1);

            // Coeff-2 accumulators are added twice to apply their factor.
            F51x4Unreduced([
                z0_1 + z0_2 + z0_2,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}
373 
374 impl From<F51x4Reduced> for F51x4Unreduced {
375     #[inline]
from(x: F51x4Reduced) -> F51x4Unreduced376     fn from(x: F51x4Reduced) -> F51x4Unreduced {
377         F51x4Unreduced(x.0)
378     }
379 }
380 
impl From<F51x4Unreduced> for F51x4Reduced {
    /// Perform one weak reduction pass, bounding the coefficients.
    ///
    /// Each limb is split into its low 51 bits and a carry; each carry
    /// propagates into the next limb, and the top carry wraps into limb 0
    /// multiplied by 19, since 2^255 = 19 (mod p).
    #[inline]
    fn from(x: F51x4Unreduced) -> F51x4Reduced {
        let mask = u64x4::splat((1 << 51) - 1);
        let r19 = u64x4::splat(19);

        // Compute carryouts in parallel
        let c0 = x.0[0] >> 51;
        let c1 = x.0[1] >> 51;
        let c2 = x.0[2] >> 51;
        let c3 = x.0[3] >> 51;
        let c4 = x.0[4] >> 51;

        unsafe {
            F51x4Reduced([
                // Wrap the top carry into limb 0 as c4 * 19.
                madd52lo(x.0[0] & mask, c4, r19),
                (x.0[1] & mask) + c0,
                (x.0[2] & mask) + c1,
                (x.0[3] & mask) + c2,
                (x.0[4] & mask) + c3,
            ])
        }
    }
}
405 
406 impl Add<F51x4Unreduced> for F51x4Unreduced {
407     type Output = F51x4Unreduced;
408     #[inline]
add(self, rhs: F51x4Unreduced) -> F51x4Unreduced409     fn add(self, rhs: F51x4Unreduced) -> F51x4Unreduced {
410         F51x4Unreduced([
411             self.0[0] + rhs.0[0],
412             self.0[1] + rhs.0[1],
413             self.0[2] + rhs.0[2],
414             self.0[3] + rhs.0[3],
415             self.0[4] + rhs.0[4],
416         ])
417     }
418 }
419 
impl<'a> Mul<(u32, u32, u32, u32)> for &'a F51x4Reduced {
    type Output = F51x4Unreduced;
    /// Multiply each field element by a small scalar, one `u32` per lane.
    ///
    /// Because the scalars fit in 32 bits, the schoolbook product has
    /// only six limbs z0..z5; the single high limb z5 is folded back into
    /// z0 using 2^255 = 19 (mod p).
    #[inline]
    fn mul(self, scalars: (u32, u32, u32, u32)) -> F51x4Unreduced {
        unsafe {
            let x = &self.0;
            // One scalar per lane, matching the lane layout of `self`.
            let y = u64x4::new(
                scalars.0 as u64,
                scalars.1 as u64,
                scalars.2 as u64,
                scalars.3 as u64,
            );
            let r19 = u64x4::splat(19);

            // Accumulators: `z*_1` terms have coefficient 1; `z*_2` terms
            // (madd52hi results, which sit at 2^52 = 2*2^51) coefficient 2.
            let mut z0_1 = u64x4::splat(0);
            let mut z1_1 = u64x4::splat(0);
            let mut z2_1 = u64x4::splat(0);
            let mut z3_1 = u64x4::splat(0);
            let mut z4_1 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);

            // Wave 0
            z4_2 = madd52hi(z4_2, y, x[3]);
            z5_2 = madd52hi(z5_2, y, x[4]);
            z4_1 = madd52lo(z4_1, y, x[4]);
            z0_1 = madd52lo(z0_1, y, x[0]);
            z3_1 = madd52lo(z3_1, y, x[3]);
            z2_1 = madd52lo(z2_1, y, x[2]);
            z1_1 = madd52lo(z1_1, y, x[1]);
            z3_2 = madd52hi(z3_2, y, x[2]);

            // Wave 1
            z2_2 = madd52hi(z2_2, y, x[1]);
            z1_2 = madd52hi(z1_2, y, x[0]);
            // Fold z5 (doubled to apply its coefficient) into z0 via
            // 2^255 = 19 (mod p).
            z0_1 = madd52lo(z0_1, z5_2 + z5_2, r19);

            F51x4Unreduced([
                z0_1,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}
470 
impl<'a, 'b> Mul<&'b F51x4Reduced> for &'a F51x4Reduced {
    type Output = F51x4Unreduced;
    /// Lane-wise field multiplication using IFMA instructions.
    ///
    /// Schoolbook multiplication of the 5-limb radix-2^51 operands into
    /// limbs z0..z9.  `z*_1` partial products carry coefficient 1 and
    /// `z*_2` partial products coefficient 2 (every `madd52hi` result
    /// sits at 2^52 = 2*2^51).  The high limbs z5..z9 are folded back
    /// into z0..z4 using 2^255 = 19 (mod p).
    #[inline]
    fn mul(self, rhs: &'b F51x4Reduced) -> F51x4Unreduced {
        unsafe {
            // Inputs
            let x = &self.0;
            let y = &rhs.0;

            // Accumulators for terms with coeff 1
            let mut z0_1 = u64x4::splat(0);
            let mut z1_1 = u64x4::splat(0);
            let mut z2_1 = u64x4::splat(0);
            let mut z3_1 = u64x4::splat(0);
            let mut z4_1 = u64x4::splat(0);
            let mut z5_1 = u64x4::splat(0);
            let mut z6_1 = u64x4::splat(0);
            let mut z7_1 = u64x4::splat(0);
            let mut z8_1 = u64x4::splat(0);

            // Accumulators for terms with coeff 2
            let mut z0_2 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);
            let mut z6_2 = u64x4::splat(0);
            let mut z7_2 = u64x4::splat(0);
            let mut z8_2 = u64x4::splat(0);
            let mut z9_2 = u64x4::splat(0);

            // LLVM doesn't seem to do much work reordering IFMA
            // instructions, so try to organize them into "waves" of 8
            // independent operations (4c latency, 0.5 c throughput
            // means 8 in flight)

            // Wave 0
            z4_1 = madd52lo(z4_1, x[2], y[2]);
            z5_2 = madd52hi(z5_2, x[2], y[2]);
            z5_1 = madd52lo(z5_1, x[4], y[1]);
            z6_2 = madd52hi(z6_2, x[4], y[1]);
            z6_1 = madd52lo(z6_1, x[4], y[2]);
            z7_2 = madd52hi(z7_2, x[4], y[2]);
            z7_1 = madd52lo(z7_1, x[4], y[3]);
            z8_2 = madd52hi(z8_2, x[4], y[3]);

            // Wave 1
            z4_1 = madd52lo(z4_1, x[3], y[1]);
            z5_2 = madd52hi(z5_2, x[3], y[1]);
            z5_1 = madd52lo(z5_1, x[3], y[2]);
            z6_2 = madd52hi(z6_2, x[3], y[2]);
            z6_1 = madd52lo(z6_1, x[3], y[3]);
            z7_2 = madd52hi(z7_2, x[3], y[3]);
            z7_1 = madd52lo(z7_1, x[3], y[4]);
            z8_2 = madd52hi(z8_2, x[3], y[4]);

            // Wave 2
            z8_1 = madd52lo(z8_1, x[4], y[4]);
            z9_2 = madd52hi(z9_2, x[4], y[4]);
            z4_1 = madd52lo(z4_1, x[4], y[0]);
            z5_2 = madd52hi(z5_2, x[4], y[0]);
            z5_1 = madd52lo(z5_1, x[2], y[3]);
            z6_2 = madd52hi(z6_2, x[2], y[3]);
            z6_1 = madd52lo(z6_1, x[2], y[4]);
            z7_2 = madd52hi(z7_2, x[2], y[4]);

            // z8 and z9 are complete; collapse the coeff-2 parts now.
            let z8 = z8_1 + z8_2 + z8_2;
            let z9 = z9_2 + z9_2;

            // Wave 3
            z3_1 = madd52lo(z3_1, x[3], y[0]);
            z4_2 = madd52hi(z4_2, x[3], y[0]);
            z4_1 = madd52lo(z4_1, x[1], y[3]);
            z5_2 = madd52hi(z5_2, x[1], y[3]);
            z5_1 = madd52lo(z5_1, x[1], y[4]);
            z6_2 = madd52hi(z6_2, x[1], y[4]);
            z2_1 = madd52lo(z2_1, x[2], y[0]);
            z3_2 = madd52hi(z3_2, x[2], y[0]);

            let z6 = z6_1 + z6_2 + z6_2;
            let z7 = z7_1 + z7_2 + z7_2;

            // Wave 4
            z3_1 = madd52lo(z3_1, x[2], y[1]);
            z4_2 = madd52hi(z4_2, x[2], y[1]);
            z4_1 = madd52lo(z4_1, x[0], y[4]);
            z5_2 = madd52hi(z5_2, x[0], y[4]);
            z1_1 = madd52lo(z1_1, x[1], y[0]);
            z2_2 = madd52hi(z2_2, x[1], y[0]);
            z2_1 = madd52lo(z2_1, x[1], y[1]);
            z3_2 = madd52hi(z3_2, x[1], y[1]);

            let z5 = z5_1 + z5_2 + z5_2;

            // Wave 5
            z3_1 = madd52lo(z3_1, x[1], y[2]);
            z4_2 = madd52hi(z4_2, x[1], y[2]);
            z0_1 = madd52lo(z0_1, x[0], y[0]);
            z1_2 = madd52hi(z1_2, x[0], y[0]);
            z1_1 = madd52lo(z1_1, x[0], y[1]);
            z2_1 = madd52lo(z2_1, x[0], y[2]);
            z2_2 = madd52hi(z2_2, x[0], y[1]);
            z3_2 = madd52hi(z3_2, x[0], y[2]);

            let mut t0 = u64x4::splat(0);
            let mut t1 = u64x4::splat(0);
            let r19 = u64x4::splat(19);

            // Waves 6-8 fold the high limbs z5..z9 into z0..z4 using
            // 2^255 = 19 (mod p), interleaved with the last partial
            // products of the low limbs.

            // Wave 6
            t0 = madd52hi(t0, r19, z9);
            t1 = madd52lo(t1, r19, z9 >> 52);
            z3_1 = madd52lo(z3_1, x[0], y[3]);
            z4_2 = madd52hi(z4_2, x[0], y[3]);
            z1_2 = madd52lo(z1_2, r19, z5 >> 52);
            z2_2 = madd52lo(z2_2, r19, z6 >> 52);
            z3_2 = madd52lo(z3_2, r19, z7 >> 52);
            z0_1 = madd52lo(z0_1, r19, z5);

            // Wave 7
            z4_1 = madd52lo(z4_1, r19, z9);
            z1_1 = madd52lo(z1_1, r19, z6);
            z0_2 = madd52lo(z0_2, r19, t0 + t1);
            z4_2 = madd52hi(z4_2, r19, z8);
            z2_1 = madd52lo(z2_1, r19, z7);
            z1_2 = madd52hi(z1_2, r19, z5);
            z2_2 = madd52hi(z2_2, r19, z6);
            z3_2 = madd52hi(z3_2, r19, z7);

            // Wave 8
            z3_1 = madd52lo(z3_1, r19, z8);
            z4_2 = madd52lo(z4_2, r19, z8 >> 52);

            // Coeff-2 accumulators are added twice to apply their factor.
            F51x4Unreduced([
                z0_1 + z0_2 + z0_2,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}
614 
#[cfg(test)]
mod test {
    use super::*;

    // NOTE(review): these tests exercise AVX512-IFMA intrinsics and
    // presumably require hardware/target-feature support to run — confirm
    // against the crate's build configuration.

    #[test]
    fn vpmadd52luq() {
        let x = u64x4::splat(2);
        let y = u64x4::splat(3);
        let mut z = u64x4::splat(5);

        z = unsafe { madd52lo(z, x, y) };

        assert_eq!(z, u64x4::splat(5 + 2 * 3));
    }

    #[test]
    fn new_split_round_trip_on_reduced_input() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();

        let ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        let splits = ax4.split();

        for i in 0..4 {
            assert_eq!(a, splits[i]);
        }
    }

    #[test]
    fn new_split_round_trip_on_unreduced_input() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        // ... but now multiply it by 16 without reducing coeffs
        let a16 = FieldElement51([
            a.0[0] << 4,
            a.0[1] << 4,
            a.0[2] << 4,
            a.0[3] << 4,
            a.0[4] << 4,
        ]);

        let a16x4 = F51x4Unreduced::new(&a16, &a16, &a16, &a16);
        let splits = a16x4.split();

        for i in 0..4 {
            assert_eq!(a16, splits[i]);
        }
    }

    #[test]
    fn test_reduction() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        // ... but now multiply it by 16 without reducing coeffs
        let abig = FieldElement51([
            a.0[0] << 4,
            a.0[1] << 4,
            a.0[2] << 4,
            a.0[3] << 4,
            a.0[4] << 4,
        ]);

        let abigx4: F51x4Reduced = F51x4Unreduced::new(&abig, &abig, &abig, &abig).into();

        let splits = F51x4Unreduced::from(abigx4).split();
        // The reduced form should equal a * 16, computed serially.
        let c = &a * &FieldElement51([(1 << 4), 0, 0, 0, 0]);

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([98098, 87987897, 0, 1, 0]).invert();
        let c = &a * &b;

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let bx4: F51x4Reduced = F51x4Unreduced::new(&b, &b, &b, &b).into();
        let cx4 = &ax4 * &bx4;

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn iterated_mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([98098, 87987897, 0, 1, 0]).invert();
        let mut c = &a * &b;
        for _i in 0..1024 {
            c = &a * &c;
            c = &b * &c;
        }

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let bx4: F51x4Reduced = F51x4Unreduced::new(&b, &b, &b, &b).into();
        let mut cx4 = &ax4 * &bx4;
        for _i in 0..1024 {
            cx4 = &ax4 * &F51x4Reduced::from(cx4);
            cx4 = &bx4 * &F51x4Reduced::from(cx4);
        }

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn square_matches_mul() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let cx4 = &ax4 * &ax4;
        let cx4_sq = ax4.square();

        let splits = cx4.split();
        let splits_sq = cx4_sq.split();

        for i in 0..4 {
            assert_eq!(splits_sq[i], splits[i]);
        }
    }

    #[test]
    fn iterated_square_matches_serial() {
        // Invert a small field element to get a big one
        let mut a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let mut ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        for _j in 0..1024 {
            a = a.square();
            ax4 = F51x4Reduced::from(ax4).square();

            let splits = ax4.split();
            for i in 0..4 {
                assert_eq!(a, splits[i]);
            }
        }
    }

    #[test]
    fn iterated_u32_mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([121665, 0, 0, 0, 0]);
        let mut c = &a * &b;
        for _i in 0..1024 {
            c = &b * &c;
        }

        let ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        let bx4 = (121665u32, 121665u32, 121665u32, 121665u32);
        let mut cx4 = &F51x4Reduced::from(ax4) * bx4;
        for _i in 0..1024 {
            cx4 = &F51x4Reduced::from(cx4) * bx4;
        }

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn shuffle_AAAA() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let x = F51x4Unreduced::new(&x0, &x1, &x2, &x3);

        let y = x.shuffle(Shuffle::AAAA);
        let splits = y.split();

        assert_eq!(splits[0], x0);
        assert_eq!(splits[1], x0);
        assert_eq!(splits[2], x0);
        assert_eq!(splits[3], x0);
    }

    #[test]
    fn blend_AB() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let x = F51x4Unreduced::new(&x0, &x1, &x2, &x3);
        let z = F51x4Unreduced::new(&x3, &x2, &x1, &x0);

        let y = x.blend(&z, Lanes::AB);
        let splits = y.split();

        assert_eq!(splits[0], x3);
        assert_eq!(splits[1], x2);
        assert_eq!(splits[2], x2);
        assert_eq!(splits[3], x3);
    }
}
825