1 // -*- mode: rust; coding: utf-8; -*-
2 //
3 // This file is part of curve25519-dalek.
4 // Copyright (c) 2018-2019 Henry de Valence
5 // See LICENSE for licensing information.
6 //
7 // Authors:
8 // - Henry de Valence <hdevalence@hdevalence.ca>
9
10 #![allow(non_snake_case)]
11
12 use core::ops::{Add, Mul, Neg};
13 use packed_simd::{u64x4, IntoBits};
14
15 use backend::serial::u64::field::FieldElement51;
16
/// A wrapper around `vpmadd52luq` that works on `u64x4`.
///
/// For each 64-bit lane, multiplies the low 52 bits of `x` and `y` and
/// adds the **low** 52 bits of the 104-bit product to `z`.
///
/// # Safety
///
/// The caller must ensure the CPU supports the AVX512IFMA (and, for the
/// 256-bit form, AVX512VL) instruction set extensions.
#[inline(always)]
unsafe fn madd52lo(z: u64x4, x: u64x4, y: u64x4) -> u64x4 {
    use core::arch::x86_64::_mm256_madd52lo_epu64;
    _mm256_madd52lo_epu64(z.into_bits(), x.into_bits(), y.into_bits()).into_bits()
}
23
/// A wrapper around `vpmadd52huq` that works on `u64x4`.
///
/// For each 64-bit lane, multiplies the low 52 bits of `x` and `y` and
/// adds the **high** 52 bits of the 104-bit product to `z`.
///
/// # Safety
///
/// The caller must ensure the CPU supports the AVX512IFMA (and, for the
/// 256-bit form, AVX512VL) instruction set extensions.
#[inline(always)]
unsafe fn madd52hi(z: u64x4, x: u64x4, y: u64x4) -> u64x4 {
    use core::arch::x86_64::_mm256_madd52hi_epu64;
    _mm256_madd52hi_epu64(z.into_bits(), x.into_bits(), y.into_bits()).into_bits()
}
30
/// A vector of four field elements in radix 2^51, with unreduced coefficients.
///
/// Storage is transposed: entry `i` of the wrapped array holds limb `i`
/// of each of the four elements, one per 64-bit lane (lanes are named
/// A, B, C, D). "Unreduced" means the limbs may exceed 51 bits.
#[derive(Copy, Clone, Debug)]
pub struct F51x4Unreduced(pub(crate) [u64x4; 5]);
34
/// A vector of four field elements in radix 2^51, with reduced coefficients.
///
/// Same transposed layout as [`F51x4Unreduced`]; produced by the weak
/// reduction in `From<F51x4Unreduced>`, which masks each limb to 51 bits
/// and propagates one carry (so limbs are small enough for IFMA inputs).
#[derive(Copy, Clone, Debug)]
pub struct F51x4Reduced(pub(crate) [u64x4; 5]);
38
/// Lane-permutation controls for `shuffle_lanes`.
///
/// Each variant name spells out, in order, which source lane (A, B, C,
/// or D) lands in each output lane; e.g. `BADC` maps `(A, B, C, D)` to
/// `(B, A, D, C)`.
#[derive(Copy, Clone)]
pub enum Shuffle {
    AAAA,
    BBBB,
    BADC,
    BACD,
    ADDA,
    CBCB,
    ABDC,
    ABAB,
    DBBD,
    CACA,
}
52
/// Permute the four 64-bit lanes of `x` according to `control`.
///
/// Implemented with `vpermq` (`_mm256_permute4x64_epi64`).  The
/// immediate encodes the source lane index (A=0 .. D=3) for each
/// destination lane, two bits per lane, low bits first.  The intrinsic
/// requires a compile-time-constant immediate, which is why each match
/// arm contains its own call rather than computing the immediate once.
#[inline(always)]
fn shuffle_lanes(x: u64x4, control: Shuffle) -> u64x4 {
    unsafe {
        use core::arch::x86_64::_mm256_permute4x64_epi64 as perm;

        match control {
            Shuffle::AAAA => perm(x.into_bits(), 0b00_00_00_00).into_bits(),
            Shuffle::BBBB => perm(x.into_bits(), 0b01_01_01_01).into_bits(),
            Shuffle::BADC => perm(x.into_bits(), 0b10_11_00_01).into_bits(),
            Shuffle::BACD => perm(x.into_bits(), 0b11_10_00_01).into_bits(),
            Shuffle::ADDA => perm(x.into_bits(), 0b00_11_11_00).into_bits(),
            Shuffle::CBCB => perm(x.into_bits(), 0b01_10_01_10).into_bits(),
            Shuffle::ABDC => perm(x.into_bits(), 0b10_11_01_00).into_bits(),
            Shuffle::ABAB => perm(x.into_bits(), 0b01_00_01_00).into_bits(),
            Shuffle::DBBD => perm(x.into_bits(), 0b11_01_01_11).into_bits(),
            Shuffle::CACA => perm(x.into_bits(), 0b00_10_00_10).into_bits(),
        }
    }
}
72
/// Lane-selection controls for `blend_lanes`.
///
/// Each variant name lists the lanes taken from the *second* operand;
/// all remaining lanes are kept from the first operand.
#[derive(Copy, Clone)]
pub enum Lanes {
    D,
    C,
    AB,
    AC,
    AD,
    BCD,
}
82
/// Blend the 64-bit lanes of `x` and `y`: lanes named by `control`
/// come from `y`, the rest from `x`.
///
/// Implemented with `vpblendd` (`_mm256_blend_epi32`), which selects
/// 32-bit lanes, so each 64-bit lane corresponds to a *pair* of mask
/// bits (hence the doubled bit patterns below).  The mask must be a
/// compile-time constant, so each arm makes its own call.
#[inline]
fn blend_lanes(x: u64x4, y: u64x4, control: Lanes) -> u64x4 {
    unsafe {
        use core::arch::x86_64::_mm256_blend_epi32 as blend;

        match control {
            Lanes::D => blend(x.into_bits(), y.into_bits(), 0b11_00_00_00).into_bits(),
            Lanes::C => blend(x.into_bits(), y.into_bits(), 0b00_11_00_00).into_bits(),
            Lanes::AB => blend(x.into_bits(), y.into_bits(), 0b00_00_11_11).into_bits(),
            Lanes::AC => blend(x.into_bits(), y.into_bits(), 0b00_11_00_11).into_bits(),
            Lanes::AD => blend(x.into_bits(), y.into_bits(), 0b11_00_00_11).into_bits(),
            Lanes::BCD => blend(x.into_bits(), y.into_bits(), 0b11_11_11_00).into_bits(),
        }
    }
}
98
99 impl F51x4Unreduced {
zero() -> F51x4Unreduced100 pub fn zero() -> F51x4Unreduced {
101 F51x4Unreduced([u64x4::splat(0); 5])
102 }
103
new( x0: &FieldElement51, x1: &FieldElement51, x2: &FieldElement51, x3: &FieldElement51, ) -> F51x4Unreduced104 pub fn new(
105 x0: &FieldElement51,
106 x1: &FieldElement51,
107 x2: &FieldElement51,
108 x3: &FieldElement51,
109 ) -> F51x4Unreduced {
110 F51x4Unreduced([
111 u64x4::new(x0.0[0], x1.0[0], x2.0[0], x3.0[0]),
112 u64x4::new(x0.0[1], x1.0[1], x2.0[1], x3.0[1]),
113 u64x4::new(x0.0[2], x1.0[2], x2.0[2], x3.0[2]),
114 u64x4::new(x0.0[3], x1.0[3], x2.0[3], x3.0[3]),
115 u64x4::new(x0.0[4], x1.0[4], x2.0[4], x3.0[4]),
116 ])
117 }
118
split(&self) -> [FieldElement51; 4]119 pub fn split(&self) -> [FieldElement51; 4] {
120 let x = &self.0;
121 [
122 FieldElement51([
123 x[0].extract(0),
124 x[1].extract(0),
125 x[2].extract(0),
126 x[3].extract(0),
127 x[4].extract(0),
128 ]),
129 FieldElement51([
130 x[0].extract(1),
131 x[1].extract(1),
132 x[2].extract(1),
133 x[3].extract(1),
134 x[4].extract(1),
135 ]),
136 FieldElement51([
137 x[0].extract(2),
138 x[1].extract(2),
139 x[2].extract(2),
140 x[3].extract(2),
141 x[4].extract(2),
142 ]),
143 FieldElement51([
144 x[0].extract(3),
145 x[1].extract(3),
146 x[2].extract(3),
147 x[3].extract(3),
148 x[4].extract(3),
149 ]),
150 ]
151 }
152
153 #[inline]
diff_sum(&self) -> F51x4Unreduced154 pub fn diff_sum(&self) -> F51x4Unreduced {
155 // tmp1 = (B, A, D, C)
156 let tmp1 = self.shuffle(Shuffle::BADC);
157 // tmp2 = (-A, B, -C, D)
158 let tmp2 = self.blend(&self.negate_lazy(), Lanes::AC);
159 // (B - A, B + A, D - C, D + C)
160 tmp1 + tmp2
161 }
162
163 #[inline]
negate_lazy(&self) -> F51x4Unreduced164 pub fn negate_lazy(&self) -> F51x4Unreduced {
165 let lo = u64x4::splat(36028797018963664u64);
166 let hi = u64x4::splat(36028797018963952u64);
167 F51x4Unreduced([
168 lo - self.0[0],
169 hi - self.0[1],
170 hi - self.0[2],
171 hi - self.0[3],
172 hi - self.0[4],
173 ])
174 }
175
176 #[inline]
shuffle(&self, control: Shuffle) -> F51x4Unreduced177 pub fn shuffle(&self, control: Shuffle) -> F51x4Unreduced {
178 F51x4Unreduced([
179 shuffle_lanes(self.0[0], control),
180 shuffle_lanes(self.0[1], control),
181 shuffle_lanes(self.0[2], control),
182 shuffle_lanes(self.0[3], control),
183 shuffle_lanes(self.0[4], control),
184 ])
185 }
186
187 #[inline]
blend(&self, other: &F51x4Unreduced, control: Lanes) -> F51x4Unreduced188 pub fn blend(&self, other: &F51x4Unreduced, control: Lanes) -> F51x4Unreduced {
189 F51x4Unreduced([
190 blend_lanes(self.0[0], other.0[0], control),
191 blend_lanes(self.0[1], other.0[1], control),
192 blend_lanes(self.0[2], other.0[2], control),
193 blend_lanes(self.0[3], other.0[3], control),
194 blend_lanes(self.0[4], other.0[4], control),
195 ])
196 }
197 }
198
199 impl Neg for F51x4Reduced {
200 type Output = F51x4Reduced;
201
neg(self) -> F51x4Reduced202 fn neg(self) -> F51x4Reduced {
203 F51x4Unreduced::from(self).negate_lazy().into()
204 }
205 }
206
207 use subtle::Choice;
208 use subtle::ConditionallySelectable;
209
210 impl ConditionallySelectable for F51x4Reduced {
211 #[inline]
conditional_select(a: &F51x4Reduced, b: &F51x4Reduced, choice: Choice) -> F51x4Reduced212 fn conditional_select(a: &F51x4Reduced, b: &F51x4Reduced, choice: Choice) -> F51x4Reduced {
213 let mask = (-(choice.unwrap_u8() as i64)) as u64;
214 let mask_vec = u64x4::splat(mask);
215 F51x4Reduced([
216 a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])),
217 a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])),
218 a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])),
219 a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])),
220 a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])),
221 ])
222 }
223
224 #[inline]
conditional_assign(&mut self, other: &F51x4Reduced, choice: Choice)225 fn conditional_assign(&mut self, other: &F51x4Reduced, choice: Choice) {
226 let mask = (-(choice.unwrap_u8() as i64)) as u64;
227 let mask_vec = u64x4::splat(mask);
228 self.0[0] ^= mask_vec & (self.0[0] ^ other.0[0]);
229 self.0[1] ^= mask_vec & (self.0[1] ^ other.0[1]);
230 self.0[2] ^= mask_vec & (self.0[2] ^ other.0[2]);
231 self.0[3] ^= mask_vec & (self.0[3] ^ other.0[3]);
232 self.0[4] ^= mask_vec & (self.0[4] ^ other.0[4]);
233 }
234 }
235
impl F51x4Reduced {
    /// Permute the lanes of every limb according to `control`.
    #[inline]
    pub fn shuffle(&self, control: Shuffle) -> F51x4Reduced {
        F51x4Reduced([
            shuffle_lanes(self.0[0], control),
            shuffle_lanes(self.0[1], control),
            shuffle_lanes(self.0[2], control),
            shuffle_lanes(self.0[3], control),
            shuffle_lanes(self.0[4], control),
        ])
    }

    /// Take the lanes named by `control` from `other`, keeping the
    /// remaining lanes from `self`.
    #[inline]
    pub fn blend(&self, other: &F51x4Reduced, control: Lanes) -> F51x4Reduced {
        F51x4Reduced([
            blend_lanes(self.0[0], other.0[0], control),
            blend_lanes(self.0[1], other.0[1], control),
            blend_lanes(self.0[2], other.0[2], control),
            blend_lanes(self.0[3], other.0[3], control),
            blend_lanes(self.0[4], other.0[4], control),
        ])
    }

    /// Square each lane: returns `(A^2, B^2, C^2, D^2)`, unreduced.
    ///
    /// Schoolbook squaring with IFMA.  Cross terms `x_i * x_j` (i != j)
    /// occur twice, so their low halves go into the "coeff. 2"
    /// accumulators (`z*_2`); because the limbs are radix 2^51 but the
    /// products split at 2^52, a product's high half carries an extra
    /// factor of 2 at the next limb (e.g. the high half of `x0*x0`
    /// lands in `z1_2`), and high halves of doubled cross terms land in
    /// the "coeff. 4" accumulators (`z*_4`, folded in via `<< 2`).
    /// Limbs 5..9 are then folded into limbs 0..4 with a factor of 19,
    /// using 2^255 = 19 (mod p).
    #[inline]
    pub fn square(&self) -> F51x4Unreduced {
        unsafe {
            let x = &self.0;

            // Represent values with coeff. 2
            let mut z0_2 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);
            let mut z6_2 = u64x4::splat(0);
            let mut z7_2 = u64x4::splat(0);
            let mut z9_2 = u64x4::splat(0);

            // Represent values with coeff. 4
            let mut z2_4 = u64x4::splat(0);
            let mut z3_4 = u64x4::splat(0);
            let mut z4_4 = u64x4::splat(0);
            let mut z5_4 = u64x4::splat(0);
            let mut z6_4 = u64x4::splat(0);
            let mut z7_4 = u64x4::splat(0);
            let mut z8_4 = u64x4::splat(0);

            let mut z0_1 = u64x4::splat(0);
            z0_1 = madd52lo(z0_1, x[0], x[0]);

            let mut z1_1 = u64x4::splat(0);
            z1_2 = madd52lo(z1_2, x[0], x[1]);
            // hi half of x0^2 carries coeff. 2 at limb 1
            z1_2 = madd52hi(z1_2, x[0], x[0]);

            z2_4 = madd52hi(z2_4, x[0], x[1]);
            // fold coeff.-4 terms into the coeff.-1 accumulator
            let mut z2_1 = z2_4 << 2;
            z2_2 = madd52lo(z2_2, x[0], x[2]);
            z2_1 = madd52lo(z2_1, x[1], x[1]);

            z3_4 = madd52hi(z3_4, x[0], x[2]);
            let mut z3_1 = z3_4 << 2;
            z3_2 = madd52lo(z3_2, x[1], x[2]);
            z3_2 = madd52lo(z3_2, x[0], x[3]);
            z3_2 = madd52hi(z3_2, x[1], x[1]);

            z4_4 = madd52hi(z4_4, x[1], x[2]);
            z4_4 = madd52hi(z4_4, x[0], x[3]);
            let mut z4_1 = z4_4 << 2;
            z4_2 = madd52lo(z4_2, x[1], x[3]);
            z4_2 = madd52lo(z4_2, x[0], x[4]);
            z4_1 = madd52lo(z4_1, x[2], x[2]);

            z5_4 = madd52hi(z5_4, x[1], x[3]);
            z5_4 = madd52hi(z5_4, x[0], x[4]);
            let mut z5_1 = z5_4 << 2;
            z5_2 = madd52lo(z5_2, x[2], x[3]);
            z5_2 = madd52lo(z5_2, x[1], x[4]);
            z5_2 = madd52hi(z5_2, x[2], x[2]);

            z6_4 = madd52hi(z6_4, x[2], x[3]);
            z6_4 = madd52hi(z6_4, x[1], x[4]);
            let mut z6_1 = z6_4 << 2;
            z6_2 = madd52lo(z6_2, x[2], x[4]);
            z6_1 = madd52lo(z6_1, x[3], x[3]);

            z7_4 = madd52hi(z7_4, x[2], x[4]);
            let mut z7_1 = z7_4 << 2;
            z7_2 = madd52lo(z7_2, x[3], x[4]);
            z7_2 = madd52hi(z7_2, x[3], x[3]);

            z8_4 = madd52hi(z8_4, x[3], x[4]);
            let mut z8_1 = z8_4 << 2;
            z8_1 = madd52lo(z8_1, x[4], x[4]);

            let mut z9_1 = u64x4::splat(0);
            z9_2 = madd52hi(z9_2, x[4], x[4]);

            // Merge coeff.-2 parts of the high limbs before reduction.
            z5_1 += z5_2 << 1;
            z6_1 += z6_2 << 1;
            z7_1 += z7_2 << 1;
            z9_1 += z9_2 << 1;

            let mut t0 = u64x4::splat(0);
            let mut t1 = u64x4::splat(0);
            let r19 = u64x4::splat(19);

            // t0 + t1 captures 19 * (bits of z9 above 2^52).
            t0 = madd52hi(t0, r19, z9_1);
            t1 = madd52lo(t1, r19, z9_1 >> 52);

            // Fold the overflow bits of limbs 5..8 down with factor 19.
            z4_2 = madd52lo(z4_2, r19, z8_1 >> 52);
            z3_2 = madd52lo(z3_2, r19, z7_1 >> 52);
            z2_2 = madd52lo(z2_2, r19, z6_1 >> 52);
            z1_2 = madd52lo(z1_2, r19, z5_1 >> 52);

            z0_2 = madd52lo(z0_2, r19, t0 + t1);
            z1_2 = madd52hi(z1_2, r19, z5_1);
            z2_2 = madd52hi(z2_2, r19, z6_1);
            z3_2 = madd52hi(z3_2, r19, z7_1);
            z4_2 = madd52hi(z4_2, r19, z8_1);

            // Fold the low halves of limbs 5..9 into limbs 0..4.
            z0_1 = madd52lo(z0_1, r19, z5_1);
            z1_1 = madd52lo(z1_1, r19, z6_1);
            z2_1 = madd52lo(z2_1, r19, z7_1);
            z3_1 = madd52lo(z3_1, r19, z8_1);
            z4_1 = madd52lo(z4_1, r19, z9_1);

            // Recombine: each coeff.-2 accumulator is added twice.
            F51x4Unreduced([
                z0_1 + z0_2 + z0_2,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}
373
374 impl From<F51x4Reduced> for F51x4Unreduced {
375 #[inline]
from(x: F51x4Reduced) -> F51x4Unreduced376 fn from(x: F51x4Reduced) -> F51x4Unreduced {
377 F51x4Unreduced(x.0)
378 }
379 }
380
impl From<F51x4Unreduced> for F51x4Reduced {
    /// Weak reduction: mask every limb to 51 bits and add each limb's
    /// carry into the next limb.  The limb-4 carry wraps around to
    /// limb 0 multiplied by 19, since 2^255 = 19 (mod p).  Carries are
    /// propagated only one step, so the result is bounded but not
    /// canonical.
    #[inline]
    fn from(x: F51x4Unreduced) -> F51x4Reduced {
        let mask = u64x4::splat((1 << 51) - 1);
        let r19 = u64x4::splat(19);

        // Compute carryouts in parallel
        let c0 = x.0[0] >> 51;
        let c1 = x.0[1] >> 51;
        let c2 = x.0[2] >> 51;
        let c3 = x.0[3] >> 51;
        let c4 = x.0[4] >> 51;

        unsafe {
            F51x4Reduced([
                // limb 0 += 19 * c4, via an IFMA multiply-add
                madd52lo(x.0[0] & mask, c4, r19),
                (x.0[1] & mask) + c0,
                (x.0[2] & mask) + c1,
                (x.0[3] & mask) + c2,
                (x.0[4] & mask) + c3,
            ])
        }
    }
}
405
406 impl Add<F51x4Unreduced> for F51x4Unreduced {
407 type Output = F51x4Unreduced;
408 #[inline]
add(self, rhs: F51x4Unreduced) -> F51x4Unreduced409 fn add(self, rhs: F51x4Unreduced) -> F51x4Unreduced {
410 F51x4Unreduced([
411 self.0[0] + rhs.0[0],
412 self.0[1] + rhs.0[1],
413 self.0[2] + rhs.0[2],
414 self.0[3] + rhs.0[3],
415 self.0[4] + rhs.0[4],
416 ])
417 }
418 }
419
impl<'a> Mul<(u32, u32, u32, u32)> for &'a F51x4Reduced {
    type Output = F51x4Unreduced;
    /// Multiply each lane by a small (u32) per-lane scalar:
    /// `(a*A, b*B, c*C, d*D)`.
    ///
    /// A single limb of scalars multiplies all five limbs; the high
    /// halves of the 52-bit-split products carry coefficient 2 relative
    /// to radix 2^51, hence the `z*_2` accumulators and the final
    /// `z*_1 + z*_2 + z*_2` recombination.  The limb-5 high half is
    /// folded into limb 0 with factor 19 (2^255 = 19 mod p).
    #[inline]
    fn mul(self, scalars: (u32, u32, u32, u32)) -> F51x4Unreduced {
        unsafe {
            let x = &self.0;
            // One scalar per lane, widened to u64.
            let y = u64x4::new(
                scalars.0 as u64,
                scalars.1 as u64,
                scalars.2 as u64,
                scalars.3 as u64,
            );
            let r19 = u64x4::splat(19);

            let mut z0_1 = u64x4::splat(0);
            let mut z1_1 = u64x4::splat(0);
            let mut z2_1 = u64x4::splat(0);
            let mut z3_1 = u64x4::splat(0);
            let mut z4_1 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);

            // Wave 0
            z4_2 = madd52hi(z4_2, y, x[3]);
            z5_2 = madd52hi(z5_2, y, x[4]);
            z4_1 = madd52lo(z4_1, y, x[4]);
            z0_1 = madd52lo(z0_1, y, x[0]);
            z3_1 = madd52lo(z3_1, y, x[3]);
            z2_1 = madd52lo(z2_1, y, x[2]);
            z1_1 = madd52lo(z1_1, y, x[1]);
            z3_2 = madd52hi(z3_2, y, x[2]);

            // Wave 2
            z2_2 = madd52hi(z2_2, y, x[1]);
            z1_2 = madd52hi(z1_2, y, x[0]);
            // Fold limb 5 (doubled, since z5_2 has coeff. 2) into
            // limb 0 with factor 19.
            z0_1 = madd52lo(z0_1, z5_2 + z5_2, r19);

            F51x4Unreduced([
                z0_1,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}
470
impl<'a, 'b> Mul<&'b F51x4Reduced> for &'a F51x4Reduced {
    type Output = F51x4Unreduced;
    /// Lane-wise field multiplication: `(A1*A2, B1*B2, C1*C2, D1*D2)`.
    ///
    /// Schoolbook multiplication with IFMA.  The limbs are radix 2^51
    /// but the products split at 2^52, so each product's high half
    /// carries coefficient 2 at the next limb: high halves go into the
    /// `z*_2` accumulators and are added twice at the end.  The high
    /// limbs z5..z9 are folded into z0..z4 with a factor of 19, using
    /// 2^255 = 19 (mod p).
    #[inline]
    fn mul(self, rhs: &'b F51x4Reduced) -> F51x4Unreduced {
        unsafe {
            // Inputs
            let x = &self.0;
            let y = &rhs.0;

            // Accumulators for terms with coeff 1
            let mut z0_1 = u64x4::splat(0);
            let mut z1_1 = u64x4::splat(0);
            let mut z2_1 = u64x4::splat(0);
            let mut z3_1 = u64x4::splat(0);
            let mut z4_1 = u64x4::splat(0);
            let mut z5_1 = u64x4::splat(0);
            let mut z6_1 = u64x4::splat(0);
            let mut z7_1 = u64x4::splat(0);
            let mut z8_1 = u64x4::splat(0);

            // Accumulators for terms with coeff 2
            let mut z0_2 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);
            let mut z6_2 = u64x4::splat(0);
            let mut z7_2 = u64x4::splat(0);
            let mut z8_2 = u64x4::splat(0);
            let mut z9_2 = u64x4::splat(0);

            // LLVM doesn't seem to do much work reordering IFMA
            // instructions, so try to organize them into "waves" of 8
            // independent operations (4c latency, 0.5 c throughput
            // means 8 in flight)

            // Wave 0
            z4_1 = madd52lo(z4_1, x[2], y[2]);
            z5_2 = madd52hi(z5_2, x[2], y[2]);
            z5_1 = madd52lo(z5_1, x[4], y[1]);
            z6_2 = madd52hi(z6_2, x[4], y[1]);
            z6_1 = madd52lo(z6_1, x[4], y[2]);
            z7_2 = madd52hi(z7_2, x[4], y[2]);
            z7_1 = madd52lo(z7_1, x[4], y[3]);
            z8_2 = madd52hi(z8_2, x[4], y[3]);

            // Wave 1
            z4_1 = madd52lo(z4_1, x[3], y[1]);
            z5_2 = madd52hi(z5_2, x[3], y[1]);
            z5_1 = madd52lo(z5_1, x[3], y[2]);
            z6_2 = madd52hi(z6_2, x[3], y[2]);
            z6_1 = madd52lo(z6_1, x[3], y[3]);
            z7_2 = madd52hi(z7_2, x[3], y[3]);
            z7_1 = madd52lo(z7_1, x[3], y[4]);
            z8_2 = madd52hi(z8_2, x[3], y[4]);

            // Wave 2
            z8_1 = madd52lo(z8_1, x[4], y[4]);
            z9_2 = madd52hi(z9_2, x[4], y[4]);
            z4_1 = madd52lo(z4_1, x[4], y[0]);
            z5_2 = madd52hi(z5_2, x[4], y[0]);
            z5_1 = madd52lo(z5_1, x[2], y[3]);
            z6_2 = madd52hi(z6_2, x[2], y[3]);
            z6_1 = madd52lo(z6_1, x[2], y[4]);
            z7_2 = madd52hi(z7_2, x[2], y[4]);

            // z8 and z9 are now complete; merge their halves.
            let z8 = z8_1 + z8_2 + z8_2;
            let z9 = z9_2 + z9_2;

            // Wave 3
            z3_1 = madd52lo(z3_1, x[3], y[0]);
            z4_2 = madd52hi(z4_2, x[3], y[0]);
            z4_1 = madd52lo(z4_1, x[1], y[3]);
            z5_2 = madd52hi(z5_2, x[1], y[3]);
            z5_1 = madd52lo(z5_1, x[1], y[4]);
            z6_2 = madd52hi(z6_2, x[1], y[4]);
            z2_1 = madd52lo(z2_1, x[2], y[0]);
            z3_2 = madd52hi(z3_2, x[2], y[0]);

            let z6 = z6_1 + z6_2 + z6_2;
            let z7 = z7_1 + z7_2 + z7_2;

            // Wave 4
            z3_1 = madd52lo(z3_1, x[2], y[1]);
            z4_2 = madd52hi(z4_2, x[2], y[1]);
            z4_1 = madd52lo(z4_1, x[0], y[4]);
            z5_2 = madd52hi(z5_2, x[0], y[4]);
            z1_1 = madd52lo(z1_1, x[1], y[0]);
            z2_2 = madd52hi(z2_2, x[1], y[0]);
            z2_1 = madd52lo(z2_1, x[1], y[1]);
            z3_2 = madd52hi(z3_2, x[1], y[1]);

            let z5 = z5_1 + z5_2 + z5_2;

            // Wave 5
            z3_1 = madd52lo(z3_1, x[1], y[2]);
            z4_2 = madd52hi(z4_2, x[1], y[2]);
            z0_1 = madd52lo(z0_1, x[0], y[0]);
            z1_2 = madd52hi(z1_2, x[0], y[0]);
            z1_1 = madd52lo(z1_1, x[0], y[1]);
            z2_1 = madd52lo(z2_1, x[0], y[2]);
            z2_2 = madd52hi(z2_2, x[0], y[1]);
            z3_2 = madd52hi(z3_2, x[0], y[2]);

            let mut t0 = u64x4::splat(0);
            let mut t1 = u64x4::splat(0);
            let r19 = u64x4::splat(19);

            // Wave 6: begin folding z5..z9 down with factor 19.
            // t0 + t1 captures 19 * (bits of z9 above 2^52).
            t0 = madd52hi(t0, r19, z9);
            t1 = madd52lo(t1, r19, z9 >> 52);
            z3_1 = madd52lo(z3_1, x[0], y[3]);
            z4_2 = madd52hi(z4_2, x[0], y[3]);
            z1_2 = madd52lo(z1_2, r19, z5 >> 52);
            z2_2 = madd52lo(z2_2, r19, z6 >> 52);
            z3_2 = madd52lo(z3_2, r19, z7 >> 52);
            z0_1 = madd52lo(z0_1, r19, z5);

            // Wave 7
            z4_1 = madd52lo(z4_1, r19, z9);
            z1_1 = madd52lo(z1_1, r19, z6);
            z0_2 = madd52lo(z0_2, r19, t0 + t1);
            z4_2 = madd52hi(z4_2, r19, z8);
            z2_1 = madd52lo(z2_1, r19, z7);
            z1_2 = madd52hi(z1_2, r19, z5);
            z2_2 = madd52hi(z2_2, r19, z6);
            z3_2 = madd52hi(z3_2, r19, z7);

            // Wave 8
            z3_1 = madd52lo(z3_1, r19, z8);
            z4_2 = madd52lo(z4_2, r19, z8 >> 52);

            // Recombine: each coeff.-2 accumulator is added twice.
            F51x4Unreduced([
                z0_1 + z0_2 + z0_2,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}
614
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn vpmadd52luq() {
        let x = u64x4::splat(2);
        let y = u64x4::splat(3);
        let mut z = u64x4::splat(5);

        z = unsafe { madd52lo(z, x, y) };

        assert_eq!(z, u64x4::splat(5 + 2 * 3));
    }

    #[test]
    fn new_split_round_trip_on_reduced_input() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();

        let ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        let splits = ax4.split();

        for i in 0..4 {
            assert_eq!(a, splits[i]);
        }
    }

    #[test]
    fn new_split_round_trip_on_unreduced_input() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        // ... but now multiply it by 16 without reducing coeffs
        let a16 = FieldElement51([
            a.0[0] << 4,
            a.0[1] << 4,
            a.0[2] << 4,
            a.0[3] << 4,
            a.0[4] << 4,
        ]);

        let a16x4 = F51x4Unreduced::new(&a16, &a16, &a16, &a16);
        let splits = a16x4.split();

        for i in 0..4 {
            assert_eq!(a16, splits[i]);
        }
    }

    #[test]
    fn test_reduction() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        // ... but now multiply it by 16 without reducing coeffs
        let abig = FieldElement51([
            a.0[0] << 4,
            a.0[1] << 4,
            a.0[2] << 4,
            a.0[3] << 4,
            a.0[4] << 4,
        ]);

        let abigx4: F51x4Reduced = F51x4Unreduced::new(&abig, &abig, &abig, &abig).into();

        let splits = F51x4Unreduced::from(abigx4).split();
        // Reduction must preserve the value mod p: abig = 16*a.
        let c = &a * &FieldElement51([(1 << 4), 0, 0, 0, 0]);

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([98098, 87987897, 0, 1, 0]).invert();
        let c = &a * &b;

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let bx4: F51x4Reduced = F51x4Unreduced::new(&b, &b, &b, &b).into();
        let cx4 = &ax4 * &bx4;

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn iterated_mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([98098, 87987897, 0, 1, 0]).invert();
        let mut c = &a * &b;
        for _i in 0..1024 {
            c = &a * &c;
            c = &b * &c;
        }

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let bx4: F51x4Reduced = F51x4Unreduced::new(&b, &b, &b, &b).into();
        let mut cx4 = &ax4 * &bx4;
        for _i in 0..1024 {
            cx4 = &ax4 * &F51x4Reduced::from(cx4);
            cx4 = &bx4 * &F51x4Reduced::from(cx4);
        }

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn square_matches_mul() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let cx4 = &ax4 * &ax4;
        let cx4_sq = ax4.square();

        let splits = cx4.split();
        let splits_sq = cx4_sq.split();

        for i in 0..4 {
            assert_eq!(splits_sq[i], splits[i]);
        }
    }

    #[test]
    fn iterated_square_matches_serial() {
        // Invert a small field element to get a big one
        let mut a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let mut ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        for _j in 0..1024 {
            a = a.square();
            ax4 = F51x4Reduced::from(ax4).square();

            let splits = ax4.split();
            for i in 0..4 {
                assert_eq!(a, splits[i]);
            }
        }
    }

    #[test]
    fn iterated_u32_mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([121665, 0, 0, 0, 0]);
        let mut c = &a * &b;
        for _i in 0..1024 {
            c = &b * &c;
        }

        let ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        let bx4 = (121665u32, 121665u32, 121665u32, 121665u32);
        let mut cx4 = &F51x4Reduced::from(ax4) * bx4;
        for _i in 0..1024 {
            cx4 = &F51x4Reduced::from(cx4) * bx4;
        }

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn shuffle_AAAA() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let x = F51x4Unreduced::new(&x0, &x1, &x2, &x3);

        let y = x.shuffle(Shuffle::AAAA);
        let splits = y.split();

        assert_eq!(splits[0], x0);
        assert_eq!(splits[1], x0);
        assert_eq!(splits[2], x0);
        assert_eq!(splits[3], x0);
    }

    #[test]
    fn blend_AB() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let x = F51x4Unreduced::new(&x0, &x1, &x2, &x3);
        let z = F51x4Unreduced::new(&x3, &x2, &x1, &x0);

        // Lanes A and B come from z, lanes C and D from x.
        let y = x.blend(&z, Lanes::AB);
        let splits = y.split();

        assert_eq!(splits[0], x3);
        assert_eq!(splits[1], x2);
        assert_eq!(splits[2], x2);
        assert_eq!(splits[3], x3);
    }
}
825