1// Copyright (c) 2017 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package field 6 7import ( 8 "bytes" 9 "crypto/rand" 10 "encoding/hex" 11 "io" 12 "math/big" 13 "math/bits" 14 mathrand "math/rand" 15 "reflect" 16 "testing" 17 "testing/quick" 18) 19 20func (v Element) String() string { 21 return hex.EncodeToString(v.Bytes()) 22} 23 24// quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks) 25// times. The default value of -quickchecks is 100. 26var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10} 27 28func generateFieldElement(rand *mathrand.Rand) Element { 29 const maskLow52Bits = (1 << 52) - 1 30 return Element{ 31 rand.Uint64() & maskLow52Bits, 32 rand.Uint64() & maskLow52Bits, 33 rand.Uint64() & maskLow52Bits, 34 rand.Uint64() & maskLow52Bits, 35 rand.Uint64() & maskLow52Bits, 36 } 37} 38 39// weirdLimbs can be combined to generate a range of edge-case field elements. 40// 0 and -1 are intentionally more weighted, as they combine well. 41var ( 42 weirdLimbs51 = []uint64{ 43 0, 0, 0, 0, 44 1, 45 19 - 1, 46 19, 47 0x2aaaaaaaaaaaa, 48 0x5555555555555, 49 (1 << 51) - 20, 50 (1 << 51) - 19, 51 (1 << 51) - 1, (1 << 51) - 1, 52 (1 << 51) - 1, (1 << 51) - 1, 53 } 54 weirdLimbs52 = []uint64{ 55 0, 0, 0, 0, 0, 0, 56 1, 57 19 - 1, 58 19, 59 0x2aaaaaaaaaaaa, 60 0x5555555555555, 61 (1 << 51) - 20, 62 (1 << 51) - 19, 63 (1 << 51) - 1, (1 << 51) - 1, 64 (1 << 51) - 1, (1 << 51) - 1, 65 (1 << 51) - 1, (1 << 51) - 1, 66 1 << 51, 67 (1 << 51) + 1, 68 (1 << 52) - 19, 69 (1 << 52) - 1, 70 } 71) 72 73func generateWeirdFieldElement(rand *mathrand.Rand) Element { 74 return Element{ 75 weirdLimbs52[rand.Intn(len(weirdLimbs52))], 76 weirdLimbs51[rand.Intn(len(weirdLimbs51))], 77 weirdLimbs51[rand.Intn(len(weirdLimbs51))], 78 weirdLimbs51[rand.Intn(len(weirdLimbs51))], 79 weirdLimbs51[rand.Intn(len(weirdLimbs51))], 80 } 81} 82 83func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value { 84 if rand.Intn(2) == 0 { 85 return reflect.ValueOf(generateWeirdFieldElement(rand)) 86 } 87 return reflect.ValueOf(generateFieldElement(rand)) 88} 89 90// isInBounds returns whether the element is within the expected bit size bounds 91// after a light reduction. 92func isInBounds(x *Element) bool { 93 return bits.Len64(x.l0) <= 52 && 94 bits.Len64(x.l1) <= 52 && 95 bits.Len64(x.l2) <= 52 && 96 bits.Len64(x.l3) <= 52 && 97 bits.Len64(x.l4) <= 52 98} 99 100func TestMultiplyDistributesOverAdd(t *testing.T) { 101 multiplyDistributesOverAdd := func(x, y, z Element) bool { 102 // Compute t1 = (x+y)*z 103 t1 := new(Element) 104 t1.Add(&x, &y) 105 t1.Multiply(t1, &z) 106 107 // Compute t2 = x*z + y*z 108 t2 := new(Element) 109 t3 := new(Element) 110 t2.Multiply(&x, &z) 111 t3.Multiply(&y, &z) 112 t2.Add(t2, t3) 113 114 return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) 115 } 116 117 if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil { 118 t.Error(err) 119 } 120} 121 122func TestMul64to128(t *testing.T) { 123 a := uint64(5) 124 b := uint64(5) 125 r := mul64(a, b) 126 if r.lo != 0x19 || r.hi != 0 { 127 t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) 128 } 129 130 a = uint64(18014398509481983) // 2^54 - 1 131 b = uint64(18014398509481983) // 2^54 - 1 132 r = mul64(a, b) 133 if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff { 134 t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) 135 } 136 137 a = uint64(1125899906842661) 138 b = uint64(2097155) 139 r = mul64(a, b) 140 r = addMul64(r, a, b) 141 r = addMul64(r, a, b) 142 r = addMul64(r, a, b) 143 r = addMul64(r, a, b) 144 if r.lo != 16888498990613035 || r.hi != 640 { 145 t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi) 146 } 147} 148 149func TestSetBytesRoundTrip(t *testing.T) { 150 f1 := func(in [32]byte, fe Element) bool { 151 fe.SetBytes(in[:]) 152 153 // Mask the most significant bit as it's ignored by SetBytes. (Now 154 // instead of earlier so we check the masking in SetBytes is working.) 155 in[len(in)-1] &= (1 << 7) - 1 156 157 return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe) 158 } 159 if err := quick.Check(f1, nil); err != nil { 160 t.Errorf("failed bytes->FE->bytes round-trip: %v", err) 161 } 162 163 f2 := func(fe, r Element) bool { 164 r.SetBytes(fe.Bytes()) 165 166 // Intentionally not using Equal not to go through Bytes again. 167 // Calling reduce because both Generate and SetBytes can produce 168 // non-canonical representations. 169 fe.reduce() 170 r.reduce() 171 return fe == r 172 } 173 if err := quick.Check(f2, nil); err != nil { 174 t.Errorf("failed FE->bytes->FE round-trip: %v", err) 175 } 176 177 // Check some fixed vectors from dalek 178 type feRTTest struct { 179 fe Element 180 b []byte 181 } 182 var tests = []feRTTest{ 183 { 184 fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}, 185 b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31}, 186 }, 187 { 188 fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}, 189 b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122}, 190 }, 191 } 192 193 for _, tt := range tests { 194 b := tt.fe.Bytes() 195 if !bytes.Equal(b, tt.b) || new(Element).SetBytes(tt.b).Equal(&tt.fe) != 1 { 196 t.Errorf("Failed fixed roundtrip: %v", tt) 197 } 198 } 199} 200 201func swapEndianness(buf []byte) []byte { 202 for i := 0; i < len(buf)/2; i++ { 203 buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i] 204 } 205 return buf 206} 207 208func TestBytesBigEquivalence(t *testing.T) { 209 f1 := func(in [32]byte, fe, fe1 Element) bool { 210 fe.SetBytes(in[:]) 211 212 in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit 213 b := new(big.Int).SetBytes(swapEndianness(in[:])) 214 fe1.fromBig(b) 215 216 if fe != fe1 { 217 return false 218 } 219 220 buf := make([]byte, 32) // pad with zeroes 221 copy(buf, swapEndianness(fe1.toBig().Bytes())) 222 223 return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1) 224 } 225 if err := quick.Check(f1, nil); err != nil { 226 t.Error(err) 227 } 228} 229 230// fromBig sets v = n, and returns v. The bit length of n must not exceed 256. 231func (v *Element) fromBig(n *big.Int) *Element { 232 if n.BitLen() > 32*8 { 233 panic("edwards25519: invalid field element input size") 234 } 235 236 buf := make([]byte, 0, 32) 237 for _, word := range n.Bits() { 238 for i := 0; i < bits.UintSize; i += 8 { 239 if len(buf) >= cap(buf) { 240 break 241 } 242 buf = append(buf, byte(word)) 243 word >>= 8 244 } 245 } 246 247 return v.SetBytes(buf[:32]) 248} 249 250func (v *Element) fromDecimal(s string) *Element { 251 n, ok := new(big.Int).SetString(s, 10) 252 if !ok { 253 panic("not a valid decimal: " + s) 254 } 255 return v.fromBig(n) 256} 257 258// toBig returns v as a big.Int. 259func (v *Element) toBig() *big.Int { 260 buf := v.Bytes() 261 262 words := make([]big.Word, 32*8/bits.UintSize) 263 for n := range words { 264 for i := 0; i < bits.UintSize; i += 8 { 265 if len(buf) == 0 { 266 break 267 } 268 words[n] |= big.Word(buf[0]) << big.Word(i) 269 buf = buf[1:] 270 } 271 } 272 273 return new(big.Int).SetBits(words) 274} 275 276func TestDecimalConstants(t *testing.T) { 277 sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752" 278 if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 { 279 t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp) 280 } 281 // d is in the parent package, and we don't want to expose d or fromDecimal. 282 // dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555" 283 // if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 { 284 // t.Errorf("d is %v, expected %v", d, exp) 285 // } 286} 287 288func TestSetBytesRoundTripEdgeCases(t *testing.T) { 289 // TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1, 290 // and between 2^255 and 2^256-1. Test both the documented SetBytes 291 // behavior, and that Bytes reduces them. 292} 293 294// Tests self-consistency between Multiply and Square. 295func TestConsistency(t *testing.T) { 296 var x Element 297 var x2, x2sq Element 298 299 x = Element{1, 1, 1, 1, 1} 300 x2.Multiply(&x, &x) 301 x2sq.Square(&x) 302 303 if x2 != x2sq { 304 t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) 305 } 306 307 var bytes [32]byte 308 309 _, err := io.ReadFull(rand.Reader, bytes[:]) 310 if err != nil { 311 t.Fatal(err) 312 } 313 x.SetBytes(bytes[:]) 314 315 x2.Multiply(&x, &x) 316 x2sq.Square(&x) 317 318 if x2 != x2sq { 319 t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) 320 } 321} 322 323func TestEqual(t *testing.T) { 324 x := Element{1, 1, 1, 1, 1} 325 y := Element{5, 4, 3, 2, 1} 326 327 eq := x.Equal(&x) 328 if eq != 1 { 329 t.Errorf("wrong about equality") 330 } 331 332 eq = x.Equal(&y) 333 if eq != 0 { 334 t.Errorf("wrong about inequality") 335 } 336} 337 338func TestInvert(t *testing.T) { 339 x := Element{1, 1, 1, 1, 1} 340 one := Element{1, 0, 0, 0, 0} 341 var xinv, r Element 342 343 xinv.Invert(&x) 344 r.Multiply(&x, &xinv) 345 r.reduce() 346 347 if one != r { 348 t.Errorf("inversion identity failed, got: %x", r) 349 } 350 351 var bytes [32]byte 352 353 _, err := io.ReadFull(rand.Reader, bytes[:]) 354 if err != nil { 355 t.Fatal(err) 356 } 357 x.SetBytes(bytes[:]) 358 359 xinv.Invert(&x) 360 r.Multiply(&x, &xinv) 361 r.reduce() 362 363 if one != r { 364 t.Errorf("random inversion identity failed, got: %x for field element %x", r, x) 365 } 366 367 zero := Element{} 368 x.Set(&zero) 369 if xx := xinv.Invert(&x); xx != &xinv { 370 t.Errorf("inverting zero did not return the receiver") 371 } else if xinv.Equal(&zero) != 1 { 372 t.Errorf("inverting zero did not return zero") 373 } 374} 375 376func TestSelectSwap(t *testing.T) { 377 a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676} 378 b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972} 379 380 var c, d Element 381 382 c.Select(&a, &b, 1) 383 d.Select(&a, &b, 0) 384 385 if c.Equal(&a) != 1 || d.Equal(&b) != 1 { 386 t.Errorf("Select failed") 387 } 388 389 c.Swap(&d, 0) 390 391 if c.Equal(&a) != 1 || d.Equal(&b) != 1 { 392 t.Errorf("Swap failed") 393 } 394 395 c.Swap(&d, 1) 396 397 if c.Equal(&b) != 1 || d.Equal(&a) != 1 { 398 t.Errorf("Swap failed") 399 } 400} 401 402func TestMult32(t *testing.T) { 403 mult32EquivalentToMul := func(x Element, y uint32) bool { 404 t1 := new(Element) 405 for i := 0; i < 100; i++ { 406 t1.Mult32(&x, y) 407 } 408 409 ty := new(Element) 410 ty.l0 = uint64(y) 411 412 t2 := new(Element) 413 for i := 0; i < 100; i++ { 414 t2.Multiply(&x, ty) 415 } 416 417 return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) 418 } 419 420 if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil { 421 t.Error(err) 422 } 423} 424 425func TestSqrtRatio(t *testing.T) { 426 // From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4. 427 type test struct { 428 u, v string 429 wasSquare int 430 r string 431 } 432 var tests = []test{ 433 // If u is 0, the function is defined to return (0, TRUE), even if v 434 // is zero. Note that where used in this package, the denominator v 435 // is never zero. 436 { 437 "0000000000000000000000000000000000000000000000000000000000000000", 438 "0000000000000000000000000000000000000000000000000000000000000000", 439 1, "0000000000000000000000000000000000000000000000000000000000000000", 440 }, 441 // 0/1 == 0² 442 { 443 "0000000000000000000000000000000000000000000000000000000000000000", 444 "0100000000000000000000000000000000000000000000000000000000000000", 445 1, "0000000000000000000000000000000000000000000000000000000000000000", 446 }, 447 // If u is non-zero and v is zero, defined to return (0, FALSE). 448 { 449 "0100000000000000000000000000000000000000000000000000000000000000", 450 "0000000000000000000000000000000000000000000000000000000000000000", 451 0, "0000000000000000000000000000000000000000000000000000000000000000", 452 }, 453 // 2/1 is not square in this field. 454 { 455 "0200000000000000000000000000000000000000000000000000000000000000", 456 "0100000000000000000000000000000000000000000000000000000000000000", 457 0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54", 458 }, 459 // 4/1 == 2² 460 { 461 "0400000000000000000000000000000000000000000000000000000000000000", 462 "0100000000000000000000000000000000000000000000000000000000000000", 463 1, "0200000000000000000000000000000000000000000000000000000000000000", 464 }, 465 // 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem 466 { 467 "0100000000000000000000000000000000000000000000000000000000000000", 468 "0400000000000000000000000000000000000000000000000000000000000000", 469 1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f", 470 }, 471 } 472 473 for i, tt := range tests { 474 u := new(Element).SetBytes(decodeHex(tt.u)) 475 v := new(Element).SetBytes(decodeHex(tt.v)) 476 want := new(Element).SetBytes(decodeHex(tt.r)) 477 got, wasSquare := new(Element).SqrtRatio(u, v) 478 if got.Equal(want) == 0 || wasSquare != tt.wasSquare { 479 t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare) 480 } 481 } 482} 483 484func TestCarryPropagate(t *testing.T) { 485 asmLikeGeneric := func(a [5]uint64) bool { 486 t1 := &Element{a[0], a[1], a[2], a[3], a[4]} 487 t2 := &Element{a[0], a[1], a[2], a[3], a[4]} 488 489 t1.carryPropagate() 490 t2.carryPropagateGeneric() 491 492 if *t1 != *t2 { 493 t.Logf("got: %#v,\nexpected: %#v", t1, t2) 494 } 495 496 return *t1 == *t2 && isInBounds(t2) 497 } 498 499 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { 500 t.Error(err) 501 } 502 503 if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) { 504 t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}") 505 } 506} 507 508func TestFeSquare(t *testing.T) { 509 asmLikeGeneric := func(a Element) bool { 510 t1 := a 511 t2 := a 512 513 feSquareGeneric(&t1, &t1) 514 feSquare(&t2, &t2) 515 516 if t1 != t2 { 517 t.Logf("got: %#v,\nexpected: %#v", t1, t2) 518 } 519 520 return t1 == t2 && isInBounds(&t2) 521 } 522 523 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { 524 t.Error(err) 525 } 526} 527 528func TestFeMul(t *testing.T) { 529 asmLikeGeneric := func(a, b Element) bool { 530 a1 := a 531 a2 := a 532 b1 := b 533 b2 := b 534 535 feMulGeneric(&a1, &a1, &b1) 536 feMul(&a2, &a2, &b2) 537 538 if a1 != a2 || b1 != b2 { 539 t.Logf("got: %#v,\nexpected: %#v", a1, a2) 540 t.Logf("got: %#v,\nexpected: %#v", b1, b2) 541 } 542 543 return a1 == a2 && isInBounds(&a2) && 544 b1 == b2 && isInBounds(&b2) 545 } 546 547 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { 548 t.Error(err) 549 } 550} 551 552func decodeHex(s string) []byte { 553 b, err := hex.DecodeString(s) 554 if err != nil { 555 panic(err) 556 } 557 return b 558} 559