1// Copyright (c) 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package field
6
7import (
8	"bytes"
9	"crypto/rand"
10	"encoding/hex"
11	"io"
12	"math/big"
13	"math/bits"
14	mathrand "math/rand"
15	"reflect"
16	"testing"
17	"testing/quick"
18)
19
20func (v Element) String() string {
21	return hex.EncodeToString(v.Bytes())
22}
23
// quickCheckConfig1024 scales quick.Check's iteration count by 1024, so every
// property test that uses it runs 1024 * -quickchecks times (the -quickchecks
// flag defaults to 100).
var quickCheckConfig1024 = &quick.Config{
	MaxCountScale: 1024, // 1 << 10
}
27
28func generateFieldElement(rand *mathrand.Rand) Element {
29	const maskLow52Bits = (1 << 52) - 1
30	return Element{
31		rand.Uint64() & maskLow52Bits,
32		rand.Uint64() & maskLow52Bits,
33		rand.Uint64() & maskLow52Bits,
34		rand.Uint64() & maskLow52Bits,
35		rand.Uint64() & maskLow52Bits,
36	}
37}
38
// weirdLimbs51 and weirdLimbs52 are tables of edge-case limb values that
// generateWeirdFieldElement combines into unusual field elements. Zero and the
// all-ones 51-bit limb (1 << 51) - 1 (i.e. "-1") appear several times so they
// are drawn more often, because they combine well.
var (
	// Limb values that fit in 51 bits.
	weirdLimbs51 = []uint64{
		// Zero, weighted x4.
		0, 0, 0, 0,
		// Small values.
		1, 19 - 1, 19,
		// Alternating bit patterns.
		0x2aaaaaaaaaaaa, 0x5555555555555,
		// Just below 2^51.
		(1 << 51) - 20, (1 << 51) - 19,
		// All ones, weighted x4.
		(1 << 51) - 1, (1 << 51) - 1, (1 << 51) - 1, (1 << 51) - 1,
	}

	// Limb values that fit in 52 bits, including some that need the 52nd bit.
	weirdLimbs52 = []uint64{
		// Zero, weighted x6.
		0, 0, 0, 0, 0, 0,
		// Small values.
		1, 19 - 1, 19,
		// Alternating bit patterns.
		0x2aaaaaaaaaaaa, 0x5555555555555,
		// Just below 2^51.
		(1 << 51) - 20, (1 << 51) - 19,
		// All ones in 51 bits, weighted x6.
		(1 << 51) - 1, (1 << 51) - 1, (1 << 51) - 1,
		(1 << 51) - 1, (1 << 51) - 1, (1 << 51) - 1,
		// Values at or above 2^51.
		1 << 51, (1 << 51) + 1,
		(1 << 52) - 19, (1 << 52) - 1,
	}
)
72
73func generateWeirdFieldElement(rand *mathrand.Rand) Element {
74	return Element{
75		weirdLimbs52[rand.Intn(len(weirdLimbs52))],
76		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
77		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
78		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
79		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
80	}
81}
82
83func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
84	if rand.Intn(2) == 0 {
85		return reflect.ValueOf(generateWeirdFieldElement(rand))
86	}
87	return reflect.ValueOf(generateFieldElement(rand))
88}
89
90// isInBounds returns whether the element is within the expected bit size bounds
91// after a light reduction.
92func isInBounds(x *Element) bool {
93	return bits.Len64(x.l0) <= 52 &&
94		bits.Len64(x.l1) <= 52 &&
95		bits.Len64(x.l2) <= 52 &&
96		bits.Len64(x.l3) <= 52 &&
97		bits.Len64(x.l4) <= 52
98}
99
100func TestMultiplyDistributesOverAdd(t *testing.T) {
101	multiplyDistributesOverAdd := func(x, y, z Element) bool {
102		// Compute t1 = (x+y)*z
103		t1 := new(Element)
104		t1.Add(&x, &y)
105		t1.Multiply(t1, &z)
106
107		// Compute t2 = x*z + y*z
108		t2 := new(Element)
109		t3 := new(Element)
110		t2.Multiply(&x, &z)
111		t3.Multiply(&y, &z)
112		t2.Add(t2, t3)
113
114		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
115	}
116
117	if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil {
118		t.Error(err)
119	}
120}
121
122func TestMul64to128(t *testing.T) {
123	a := uint64(5)
124	b := uint64(5)
125	r := mul64(a, b)
126	if r.lo != 0x19 || r.hi != 0 {
127		t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
128	}
129
130	a = uint64(18014398509481983) // 2^54 - 1
131	b = uint64(18014398509481983) // 2^54 - 1
132	r = mul64(a, b)
133	if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff {
134		t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
135	}
136
137	a = uint64(1125899906842661)
138	b = uint64(2097155)
139	r = mul64(a, b)
140	r = addMul64(r, a, b)
141	r = addMul64(r, a, b)
142	r = addMul64(r, a, b)
143	r = addMul64(r, a, b)
144	if r.lo != 16888498990613035 || r.hi != 640 {
145		t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi)
146	}
147}
148
149func TestSetBytesRoundTrip(t *testing.T) {
150	f1 := func(in [32]byte, fe Element) bool {
151		fe.SetBytes(in[:])
152
153		// Mask the most significant bit as it's ignored by SetBytes. (Now
154		// instead of earlier so we check the masking in SetBytes is working.)
155		in[len(in)-1] &= (1 << 7) - 1
156
157		return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
158	}
159	if err := quick.Check(f1, nil); err != nil {
160		t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
161	}
162
163	f2 := func(fe, r Element) bool {
164		r.SetBytes(fe.Bytes())
165
166		// Intentionally not using Equal not to go through Bytes again.
167		// Calling reduce because both Generate and SetBytes can produce
168		// non-canonical representations.
169		fe.reduce()
170		r.reduce()
171		return fe == r
172	}
173	if err := quick.Check(f2, nil); err != nil {
174		t.Errorf("failed FE->bytes->FE round-trip: %v", err)
175	}
176
177	// Check some fixed vectors from dalek
178	type feRTTest struct {
179		fe Element
180		b  []byte
181	}
182	var tests = []feRTTest{
183		{
184			fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
185			b:  []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
186		},
187		{
188			fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
189			b:  []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
190		},
191	}
192
193	for _, tt := range tests {
194		b := tt.fe.Bytes()
195		if !bytes.Equal(b, tt.b) || new(Element).SetBytes(tt.b).Equal(&tt.fe) != 1 {
196			t.Errorf("Failed fixed roundtrip: %v", tt)
197		}
198	}
199}
200
// swapEndianness reverses buf in place and returns it. It converts between
// big- and little-endian byte orders.
func swapEndianness(buf []byte) []byte {
	for lo, hi := 0, len(buf)-1; lo < hi; lo, hi = lo+1, hi-1 {
		buf[lo], buf[hi] = buf[hi], buf[lo]
	}
	return buf
}
207
208func TestBytesBigEquivalence(t *testing.T) {
209	f1 := func(in [32]byte, fe, fe1 Element) bool {
210		fe.SetBytes(in[:])
211
212		in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit
213		b := new(big.Int).SetBytes(swapEndianness(in[:]))
214		fe1.fromBig(b)
215
216		if fe != fe1 {
217			return false
218		}
219
220		buf := make([]byte, 32) // pad with zeroes
221		copy(buf, swapEndianness(fe1.toBig().Bytes()))
222
223		return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1)
224	}
225	if err := quick.Check(f1, nil); err != nil {
226		t.Error(err)
227	}
228}
229
230// fromBig sets v = n, and returns v. The bit length of n must not exceed 256.
231func (v *Element) fromBig(n *big.Int) *Element {
232	if n.BitLen() > 32*8 {
233		panic("edwards25519: invalid field element input size")
234	}
235
236	buf := make([]byte, 0, 32)
237	for _, word := range n.Bits() {
238		for i := 0; i < bits.UintSize; i += 8 {
239			if len(buf) >= cap(buf) {
240				break
241			}
242			buf = append(buf, byte(word))
243			word >>= 8
244		}
245	}
246
247	return v.SetBytes(buf[:32])
248}
249
250func (v *Element) fromDecimal(s string) *Element {
251	n, ok := new(big.Int).SetString(s, 10)
252	if !ok {
253		panic("not a valid decimal: " + s)
254	}
255	return v.fromBig(n)
256}
257
258// toBig returns v as a big.Int.
259func (v *Element) toBig() *big.Int {
260	buf := v.Bytes()
261
262	words := make([]big.Word, 32*8/bits.UintSize)
263	for n := range words {
264		for i := 0; i < bits.UintSize; i += 8 {
265			if len(buf) == 0 {
266				break
267			}
268			words[n] |= big.Word(buf[0]) << big.Word(i)
269			buf = buf[1:]
270		}
271	}
272
273	return new(big.Int).SetBits(words)
274}
275
276func TestDecimalConstants(t *testing.T) {
277	sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
278	if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
279		t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
280	}
281	// d is in the parent package, and we don't want to expose d or fromDecimal.
282	// dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555"
283	// if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 {
284	// 	t.Errorf("d is %v, expected %v", d, exp)
285	// }
286}
287
288func TestSetBytesRoundTripEdgeCases(t *testing.T) {
289	// TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1,
290	// and between 2^255 and 2^256-1. Test both the documented SetBytes
291	// behavior, and that Bytes reduces them.
292}
293
294// Tests self-consistency between Multiply and Square.
295func TestConsistency(t *testing.T) {
296	var x Element
297	var x2, x2sq Element
298
299	x = Element{1, 1, 1, 1, 1}
300	x2.Multiply(&x, &x)
301	x2sq.Square(&x)
302
303	if x2 != x2sq {
304		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
305	}
306
307	var bytes [32]byte
308
309	_, err := io.ReadFull(rand.Reader, bytes[:])
310	if err != nil {
311		t.Fatal(err)
312	}
313	x.SetBytes(bytes[:])
314
315	x2.Multiply(&x, &x)
316	x2sq.Square(&x)
317
318	if x2 != x2sq {
319		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
320	}
321}
322
323func TestEqual(t *testing.T) {
324	x := Element{1, 1, 1, 1, 1}
325	y := Element{5, 4, 3, 2, 1}
326
327	eq := x.Equal(&x)
328	if eq != 1 {
329		t.Errorf("wrong about equality")
330	}
331
332	eq = x.Equal(&y)
333	if eq != 0 {
334		t.Errorf("wrong about inequality")
335	}
336}
337
338func TestInvert(t *testing.T) {
339	x := Element{1, 1, 1, 1, 1}
340	one := Element{1, 0, 0, 0, 0}
341	var xinv, r Element
342
343	xinv.Invert(&x)
344	r.Multiply(&x, &xinv)
345	r.reduce()
346
347	if one != r {
348		t.Errorf("inversion identity failed, got: %x", r)
349	}
350
351	var bytes [32]byte
352
353	_, err := io.ReadFull(rand.Reader, bytes[:])
354	if err != nil {
355		t.Fatal(err)
356	}
357	x.SetBytes(bytes[:])
358
359	xinv.Invert(&x)
360	r.Multiply(&x, &xinv)
361	r.reduce()
362
363	if one != r {
364		t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
365	}
366
367	zero := Element{}
368	x.Set(&zero)
369	if xx := xinv.Invert(&x); xx != &xinv {
370		t.Errorf("inverting zero did not return the receiver")
371	} else if xinv.Equal(&zero) != 1 {
372		t.Errorf("inverting zero did not return zero")
373	}
374}
375
376func TestSelectSwap(t *testing.T) {
377	a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
378	b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
379
380	var c, d Element
381
382	c.Select(&a, &b, 1)
383	d.Select(&a, &b, 0)
384
385	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
386		t.Errorf("Select failed")
387	}
388
389	c.Swap(&d, 0)
390
391	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
392		t.Errorf("Swap failed")
393	}
394
395	c.Swap(&d, 1)
396
397	if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
398		t.Errorf("Swap failed")
399	}
400}
401
402func TestMult32(t *testing.T) {
403	mult32EquivalentToMul := func(x Element, y uint32) bool {
404		t1 := new(Element)
405		for i := 0; i < 100; i++ {
406			t1.Mult32(&x, y)
407		}
408
409		ty := new(Element)
410		ty.l0 = uint64(y)
411
412		t2 := new(Element)
413		for i := 0; i < 100; i++ {
414			t2.Multiply(&x, ty)
415		}
416
417		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
418	}
419
420	if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil {
421		t.Error(err)
422	}
423}
424
425func TestSqrtRatio(t *testing.T) {
426	// From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4.
427	type test struct {
428		u, v      string
429		wasSquare int
430		r         string
431	}
432	var tests = []test{
433		// If u is 0, the function is defined to return (0, TRUE), even if v
434		// is zero. Note that where used in this package, the denominator v
435		// is never zero.
436		{
437			"0000000000000000000000000000000000000000000000000000000000000000",
438			"0000000000000000000000000000000000000000000000000000000000000000",
439			1, "0000000000000000000000000000000000000000000000000000000000000000",
440		},
441		// 0/1 == 0²
442		{
443			"0000000000000000000000000000000000000000000000000000000000000000",
444			"0100000000000000000000000000000000000000000000000000000000000000",
445			1, "0000000000000000000000000000000000000000000000000000000000000000",
446		},
447		// If u is non-zero and v is zero, defined to return (0, FALSE).
448		{
449			"0100000000000000000000000000000000000000000000000000000000000000",
450			"0000000000000000000000000000000000000000000000000000000000000000",
451			0, "0000000000000000000000000000000000000000000000000000000000000000",
452		},
453		// 2/1 is not square in this field.
454		{
455			"0200000000000000000000000000000000000000000000000000000000000000",
456			"0100000000000000000000000000000000000000000000000000000000000000",
457			0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54",
458		},
459		// 4/1 == 2²
460		{
461			"0400000000000000000000000000000000000000000000000000000000000000",
462			"0100000000000000000000000000000000000000000000000000000000000000",
463			1, "0200000000000000000000000000000000000000000000000000000000000000",
464		},
465		// 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem
466		{
467			"0100000000000000000000000000000000000000000000000000000000000000",
468			"0400000000000000000000000000000000000000000000000000000000000000",
469			1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f",
470		},
471	}
472
473	for i, tt := range tests {
474		u := new(Element).SetBytes(decodeHex(tt.u))
475		v := new(Element).SetBytes(decodeHex(tt.v))
476		want := new(Element).SetBytes(decodeHex(tt.r))
477		got, wasSquare := new(Element).SqrtRatio(u, v)
478		if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
479			t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
480		}
481	}
482}
483
484func TestCarryPropagate(t *testing.T) {
485	asmLikeGeneric := func(a [5]uint64) bool {
486		t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
487		t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
488
489		t1.carryPropagate()
490		t2.carryPropagateGeneric()
491
492		if *t1 != *t2 {
493			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
494		}
495
496		return *t1 == *t2 && isInBounds(t2)
497	}
498
499	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
500		t.Error(err)
501	}
502
503	if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) {
504		t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}")
505	}
506}
507
508func TestFeSquare(t *testing.T) {
509	asmLikeGeneric := func(a Element) bool {
510		t1 := a
511		t2 := a
512
513		feSquareGeneric(&t1, &t1)
514		feSquare(&t2, &t2)
515
516		if t1 != t2 {
517			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
518		}
519
520		return t1 == t2 && isInBounds(&t2)
521	}
522
523	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
524		t.Error(err)
525	}
526}
527
528func TestFeMul(t *testing.T) {
529	asmLikeGeneric := func(a, b Element) bool {
530		a1 := a
531		a2 := a
532		b1 := b
533		b2 := b
534
535		feMulGeneric(&a1, &a1, &b1)
536		feMul(&a2, &a2, &b2)
537
538		if a1 != a2 || b1 != b2 {
539			t.Logf("got: %#v,\nexpected: %#v", a1, a2)
540			t.Logf("got: %#v,\nexpected: %#v", b1, b2)
541		}
542
543		return a1 == a2 && isInBounds(&a2) &&
544			b1 == b2 && isInBounds(&b2)
545	}
546
547	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
548		t.Error(err)
549	}
550}
551
// decodeHex decodes the hexadecimal string s and panics if s is not valid hex.
// It is intended for hard-coded test vectors.
func decodeHex(s string) []byte {
	decoded, err := hex.DecodeString(s)
	if err == nil {
		return decoded
	}
	panic(err)
}
559