1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package poly1305
6
7import (
8	"crypto/rand"
9	"encoding/binary"
10	"encoding/hex"
11	"flag"
12	"testing"
13	"unsafe"
14)
15
// stressFlag gates slow stress tests such as TestBurnin; they are skipped
// unless the test binary is invoked with -stress.
var (
	stressFlag = flag.Bool("stress", false, "run slow stress tests")
)
17
// test describes one Poly1305 test vector. Every field is a hex string:
// in is the message, key the 32-byte one-time key, tag the expected 16-byte
// authenticator, and state an optional initial accumulator (an empty string
// means the accumulator starts at zero).
type test struct {
	in, key, tag, state string
}
24
25func (t *test) Input() []byte {
26	in, err := hex.DecodeString(t.in)
27	if err != nil {
28		panic(err)
29	}
30	return in
31}
32
33func (t *test) Key() [32]byte {
34	buf, err := hex.DecodeString(t.key)
35	if err != nil {
36		panic(err)
37	}
38	var key [32]byte
39	copy(key[:], buf[:32])
40	return key
41}
42
43func (t *test) Tag() [16]byte {
44	buf, err := hex.DecodeString(t.tag)
45	if err != nil {
46		panic(err)
47	}
48	var tag [16]byte
49	copy(tag[:], buf[:16])
50	return tag
51}
52
53func (t *test) InitialState() [3]uint64 {
54	// state is hex encoded in big-endian byte order
55	if t.state == "" {
56		return [3]uint64{0, 0, 0}
57	}
58	buf, err := hex.DecodeString(t.state)
59	if err != nil {
60		panic(err)
61	}
62	if len(buf) != 3*8 {
63		panic("incorrect state length")
64	}
65	return [3]uint64{
66		binary.BigEndian.Uint64(buf[16:24]),
67		binary.BigEndian.Uint64(buf[8:16]),
68		binary.BigEndian.Uint64(buf[0:8]),
69	}
70}
71
72func testSum(t *testing.T, unaligned bool, sumImpl func(tag *[TagSize]byte, msg []byte, key *[32]byte)) {
73	var tag [16]byte
74	for i, v := range testData {
75		// cannot set initial state before calling sum, so skip those tests
76		if v.InitialState() != [3]uint64{0, 0, 0} {
77			continue
78		}
79
80		in := v.Input()
81		if unaligned {
82			in = unalignBytes(in)
83		}
84		key := v.Key()
85		sumImpl(&tag, in, &key)
86		if tag != v.Tag() {
87			t.Errorf("%d: expected %x, got %x", i, v.Tag(), tag[:])
88		}
89		if !Verify(&tag, in, &key) {
90			t.Errorf("%d: tag didn't verify", i)
91		}
92		// If the key is zero, the tag will always be zero, independent of the input.
93		if len(in) > 0 && key != [32]byte{} {
94			in[0] ^= 0xff
95			if Verify(&tag, in, &key) {
96				t.Errorf("%d: tag verified after altering the input", i)
97			}
98			in[0] ^= 0xff
99		}
100		// If the input is empty, the tag only depends on the second half of the key.
101		if len(in) > 0 {
102			key[0] ^= 0xff
103			if Verify(&tag, in, &key) {
104				t.Errorf("%d: tag verified after altering the key", i)
105			}
106			key[0] ^= 0xff
107		}
108		tag[0] ^= 0xff
109		if Verify(&tag, in, &key) {
110			t.Errorf("%d: tag verified after altering the tag", i)
111		}
112		tag[0] ^= 0xff
113	}
114}
115
116func TestBurnin(t *testing.T) {
117	// This test can be used to sanity-check significant changes. It can
118	// take about many minutes to run, even on fast machines. It's disabled
119	// by default.
120	if !*stressFlag {
121		t.Skip("skipping without -stress")
122	}
123
124	var key [32]byte
125	var input [25]byte
126	var output [16]byte
127
128	for i := range key {
129		key[i] = 1
130	}
131	for i := range input {
132		input[i] = 2
133	}
134
135	for i := uint64(0); i < 1e10; i++ {
136		Sum(&output, input[:], &key)
137		copy(key[0:], output[:])
138		copy(key[16:], output[:])
139		copy(input[:], output[:])
140		copy(input[16:], output[:])
141	}
142
143	const expected = "5e3b866aea0b636d240c83c428f84bfa"
144	if got := hex.EncodeToString(output[:]); got != expected {
145		t.Errorf("expected %s, got %s", expected, got)
146	}
147}
148
149func TestSum(t *testing.T)                 { testSum(t, false, Sum) }
150func TestSumUnaligned(t *testing.T)        { testSum(t, true, Sum) }
151func TestSumGeneric(t *testing.T)          { testSum(t, false, sumGeneric) }
152func TestSumGenericUnaligned(t *testing.T) { testSum(t, true, sumGeneric) }
153
154func TestWriteGeneric(t *testing.T)          { testWriteGeneric(t, false) }
155func TestWriteGenericUnaligned(t *testing.T) { testWriteGeneric(t, true) }
156func TestWrite(t *testing.T)                 { testWrite(t, false) }
157func TestWriteUnaligned(t *testing.T)        { testWrite(t, true) }
158
159func testWriteGeneric(t *testing.T, unaligned bool) {
160	for i, v := range testData {
161		key := v.Key()
162		input := v.Input()
163		var out [16]byte
164
165		if unaligned {
166			input = unalignBytes(input)
167		}
168		h := newMACGeneric(&key)
169		if s := v.InitialState(); s != [3]uint64{0, 0, 0} {
170			h.macState.h = s
171		}
172		n, err := h.Write(input[:len(input)/3])
173		if err != nil || n != len(input[:len(input)/3]) {
174			t.Errorf("#%d: unexpected Write results: n = %d, err = %v", i, n, err)
175		}
176		n, err = h.Write(input[len(input)/3:])
177		if err != nil || n != len(input[len(input)/3:]) {
178			t.Errorf("#%d: unexpected Write results: n = %d, err = %v", i, n, err)
179		}
180		h.Sum(&out)
181		if tag := v.Tag(); out != tag {
182			t.Errorf("%d: expected %x, got %x", i, tag[:], out[:])
183		}
184	}
185}
186
187func testWrite(t *testing.T, unaligned bool) {
188	for i, v := range testData {
189		key := v.Key()
190		input := v.Input()
191		var out [16]byte
192
193		if unaligned {
194			input = unalignBytes(input)
195		}
196		h := New(&key)
197		if s := v.InitialState(); s != [3]uint64{0, 0, 0} {
198			h.macState.h = s
199		}
200		n, err := h.Write(input[:len(input)/3])
201		if err != nil || n != len(input[:len(input)/3]) {
202			t.Errorf("#%d: unexpected Write results: n = %d, err = %v", i, n, err)
203		}
204		n, err = h.Write(input[len(input)/3:])
205		if err != nil || n != len(input[len(input)/3:]) {
206			t.Errorf("#%d: unexpected Write results: n = %d, err = %v", i, n, err)
207		}
208		h.Sum(out[:0])
209		tag := v.Tag()
210		if out != tag {
211			t.Errorf("%d: expected %x, got %x", i, tag[:], out[:])
212		}
213		if !h.Verify(tag[:]) {
214			t.Errorf("%d: Verify failed", i)
215		}
216		tag[0] ^= 0xff
217		if h.Verify(tag[:]) {
218			t.Errorf("%d: Verify succeeded after modifying the tag", i)
219		}
220	}
221}
222
223func benchmarkSum(b *testing.B, size int, unaligned bool) {
224	var out [16]byte
225	var key [32]byte
226	in := make([]byte, size)
227	if unaligned {
228		in = unalignBytes(in)
229	}
230	rand.Read(in)
231	b.SetBytes(int64(len(in)))
232	b.ResetTimer()
233	for i := 0; i < b.N; i++ {
234		Sum(&out, in, &key)
235	}
236}
237
238func benchmarkWrite(b *testing.B, size int, unaligned bool) {
239	var key [32]byte
240	h := New(&key)
241	in := make([]byte, size)
242	if unaligned {
243		in = unalignBytes(in)
244	}
245	rand.Read(in)
246	b.SetBytes(int64(len(in)))
247	b.ResetTimer()
248	for i := 0; i < b.N; i++ {
249		h.Write(in)
250	}
251}
252
253func Benchmark64(b *testing.B)          { benchmarkSum(b, 64, false) }
254func Benchmark1K(b *testing.B)          { benchmarkSum(b, 1024, false) }
255func Benchmark2M(b *testing.B)          { benchmarkSum(b, 2*1024*1024, false) }
256func Benchmark64Unaligned(b *testing.B) { benchmarkSum(b, 64, true) }
257func Benchmark1KUnaligned(b *testing.B) { benchmarkSum(b, 1024, true) }
258func Benchmark2MUnaligned(b *testing.B) { benchmarkSum(b, 2*1024*1024, true) }
259
260func BenchmarkWrite64(b *testing.B)          { benchmarkWrite(b, 64, false) }
261func BenchmarkWrite1K(b *testing.B)          { benchmarkWrite(b, 1024, false) }
262func BenchmarkWrite2M(b *testing.B)          { benchmarkWrite(b, 2*1024*1024, false) }
263func BenchmarkWrite64Unaligned(b *testing.B) { benchmarkWrite(b, 64, true) }
264func BenchmarkWrite1KUnaligned(b *testing.B) { benchmarkWrite(b, 1024, true) }
265func BenchmarkWrite2MUnaligned(b *testing.B) { benchmarkWrite(b, 2*1024*1024, true) }
266
// unalignBytes returns a fresh copy of in whose first element is not aligned
// to a uint32 boundary.
//
// It over-allocates by one byte. If the allocation happens to start on a
// uint32 boundary, the copy begins one byte in; otherwise the allocation's
// own start is already misaligned and is used as is.
func unalignBytes(in []byte) []byte {
	buf := make([]byte, len(in)+1)
	start := 0
	if uintptr(unsafe.Pointer(&buf[0]))%unsafe.Alignof(uint32(0)) == 0 {
		start = 1
	}
	out := buf[start : start+len(in)]
	copy(out, in)
	return out
}
277