1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sha3
6
// Tests include all the ShortMsgKATs provided by the Keccak team at
// https://github.com/gvanas/KeccakCodePackage
//
// They only include the zero-bit case of the bitwise test vectors
// published by NIST in the draft of FIPS-202.
12
13import (
14	"bytes"
15	"compress/flate"
16	"encoding/hex"
17	"encoding/json"
18	"fmt"
19	"hash"
20	"os"
21	"strings"
22	"testing"
23)
24
// Shared test inputs: katFilename is the deflate-compressed JSON file holding
// the known-answer tests, and testString is a fixed message absorbed by the
// SHAKE squeezing test.
const (
	katFilename = "testdata/keccakKats.json.deflate"
	testString  = "brekeccakkeccak koax koax"
)
29
30// Internal-use instances of SHAKE used to test against KATs.
31func newHashShake128() hash.Hash {
32	return &state{rate: 168, dsbyte: 0x1f, outputLen: 512}
33}
34func newHashShake256() hash.Hash {
35	return &state{rate: 136, dsbyte: 0x1f, outputLen: 512}
36}
37
38// testDigests contains functions returning hash.Hash instances
39// with output-length equal to the KAT length for SHA-3, Keccak
40// and SHAKE instances.
41var testDigests = map[string]func() hash.Hash{
42	"SHA3-224":   New224,
43	"SHA3-256":   New256,
44	"SHA3-384":   New384,
45	"SHA3-512":   New512,
46	"Keccak-256": NewLegacyKeccak256,
47	"SHAKE128":   newHashShake128,
48	"SHAKE256":   newHashShake256,
49}
50
51// testShakes contains functions that return ShakeHash instances for
52// testing the ShakeHash-specific interface.
53var testShakes = map[string]func() ShakeHash{
54	"SHAKE128": NewShake128,
55	"SHAKE256": NewShake256,
56}
57
// decodeHex returns the raw bytes encoded by the hex string s. It panics if s
// is not valid hex; it is only meant for hard-coded test vectors.
func decodeHex(s string) []byte {
	raw, err := hex.DecodeString(s)
	if err == nil {
		return raw
	}
	panic(err)
}
66
// KeccakKats is the decoded (unmarshaled) form of the JSON known-answer-test
// file read by TestKeccakKats. Kats maps a function name, which is looked up
// in testDigests, to that function's list of test cases.
type KeccakKats struct {
	Kats map[string][]struct {
		Digest  string `json:"digest"`  // expected output as upper-case hex
		Length  int64  `json:"length"`  // message length in bits (hashed as Length/8 bytes)
		Message string `json:"message"` // input message as hex
	}
}
75
76func testUnalignedAndGeneric(t *testing.T, testf func(impl string)) {
77	xorInOrig, copyOutOrig := xorIn, copyOut
78	xorIn, copyOut = xorInGeneric, copyOutGeneric
79	testf("generic")
80	if xorImplementationUnaligned != "generic" {
81		xorIn, copyOut = xorInUnaligned, copyOutUnaligned
82		testf("unaligned")
83	}
84	xorIn, copyOut = xorInOrig, copyOutOrig
85}
86
87// TestKeccakKats tests the SHA-3 and Shake implementations against all the
88// ShortMsgKATs from https://github.com/gvanas/KeccakCodePackage
89// (The testvectors are stored in keccakKats.json.deflate due to their length.)
90func TestKeccakKats(t *testing.T) {
91	testUnalignedAndGeneric(t, func(impl string) {
92		// Read the KATs.
93		deflated, err := os.Open(katFilename)
94		if err != nil {
95			t.Errorf("error opening %s: %s", katFilename, err)
96		}
97		file := flate.NewReader(deflated)
98		dec := json.NewDecoder(file)
99		var katSet KeccakKats
100		err = dec.Decode(&katSet)
101		if err != nil {
102			t.Errorf("error decoding KATs: %s", err)
103		}
104
105		// Do the KATs.
106		for functionName, kats := range katSet.Kats {
107			d := testDigests[functionName]()
108			for _, kat := range kats {
109				d.Reset()
110				in, err := hex.DecodeString(kat.Message)
111				if err != nil {
112					t.Errorf("error decoding KAT: %s", err)
113				}
114				d.Write(in[:kat.Length/8])
115				got := strings.ToUpper(hex.EncodeToString(d.Sum(nil)))
116				if got != kat.Digest {
117					t.Errorf("function=%s, implementation=%s, length=%d\nmessage:\n  %s\ngot:\n  %s\nwanted:\n %s",
118						functionName, impl, kat.Length, kat.Message, got, kat.Digest)
119					t.Logf("wanted %+v", kat)
120					t.FailNow()
121				}
122				continue
123			}
124		}
125	})
126}
127
128// TestKeccak does a basic test of the non-standardized Keccak hash functions.
129func TestKeccak(t *testing.T) {
130	tests := []struct {
131		fn   func() hash.Hash
132		data []byte
133		want string
134	}{
135		{
136			NewLegacyKeccak256,
137			[]byte("abc"),
138			"4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45",
139		},
140	}
141
142	for _, u := range tests {
143		h := u.fn()
144		h.Write(u.data)
145		got := h.Sum(nil)
146		want := decodeHex(u.want)
147		if !bytes.Equal(got, want) {
148			t.Errorf("unexpected hash for size %d: got '%x' want '%s'", h.Size()*8, got, u.want)
149		}
150	}
151}
152
153// TestUnalignedWrite tests that writing data in an arbitrary pattern with
154// small input buffers.
155func TestUnalignedWrite(t *testing.T) {
156	testUnalignedAndGeneric(t, func(impl string) {
157		buf := sequentialBytes(0x10000)
158		for alg, df := range testDigests {
159			d := df()
160			d.Reset()
161			d.Write(buf)
162			want := d.Sum(nil)
163			d.Reset()
164			for i := 0; i < len(buf); {
165				// Cycle through offsets which make a 137 byte sequence.
166				// Because 137 is prime this sequence should exercise all corner cases.
167				offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1}
168				for _, j := range offsets {
169					if v := len(buf) - i; v < j {
170						j = v
171					}
172					d.Write(buf[i : i+j])
173					i += j
174				}
175			}
176			got := d.Sum(nil)
177			if !bytes.Equal(got, want) {
178				t.Errorf("Unaligned writes, implementation=%s, alg=%s\ngot %q, want %q", impl, alg, got, want)
179			}
180		}
181	})
182}
183
184// TestAppend checks that appending works when reallocation is necessary.
185func TestAppend(t *testing.T) {
186	testUnalignedAndGeneric(t, func(impl string) {
187		d := New224()
188
189		for capacity := 2; capacity <= 66; capacity += 64 {
190			// The first time around the loop, Sum will have to reallocate.
191			// The second time, it will not.
192			buf := make([]byte, 2, capacity)
193			d.Reset()
194			d.Write([]byte{0xcc})
195			buf = d.Sum(buf)
196			expected := "0000DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39"
197			if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected {
198				t.Errorf("got %s, want %s", got, expected)
199			}
200		}
201	})
202}
203
204// TestAppendNoRealloc tests that appending works when no reallocation is necessary.
205func TestAppendNoRealloc(t *testing.T) {
206	testUnalignedAndGeneric(t, func(impl string) {
207		buf := make([]byte, 1, 200)
208		d := New224()
209		d.Write([]byte{0xcc})
210		buf = d.Sum(buf)
211		expected := "00DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39"
212		if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected {
213			t.Errorf("%s: got %s, want %s", impl, got, expected)
214		}
215	})
216}
217
218// TestSqueezing checks that squeezing the full output a single time produces
219// the same output as repeatedly squeezing the instance.
220func TestSqueezing(t *testing.T) {
221	testUnalignedAndGeneric(t, func(impl string) {
222		for functionName, newShakeHash := range testShakes {
223			d0 := newShakeHash()
224			d0.Write([]byte(testString))
225			ref := make([]byte, 32)
226			d0.Read(ref)
227
228			d1 := newShakeHash()
229			d1.Write([]byte(testString))
230			var multiple []byte
231			for range ref {
232				one := make([]byte, 1)
233				d1.Read(one)
234				multiple = append(multiple, one...)
235			}
236			if !bytes.Equal(ref, multiple) {
237				t.Errorf("%s (%s): squeezing %d bytes one at a time failed", functionName, impl, len(ref))
238			}
239		}
240	})
241}
242
// sequentialBytes returns a slice of length size whose i-th byte is byte(i),
// so the values run 0x00, 0x01, ... and wrap around after 0xff.
func sequentialBytes(size int) []byte {
	out := make([]byte, 0, size)
	for i := 0; i < size; i++ {
		out = append(out, byte(i))
	}
	return out
}
251
252// BenchmarkPermutationFunction measures the speed of the permutation function
253// with no input data.
254func BenchmarkPermutationFunction(b *testing.B) {
255	b.SetBytes(int64(200))
256	var lanes [25]uint64
257	for i := 0; i < b.N; i++ {
258		keccakF1600(&lanes)
259	}
260}
261
262// benchmarkHash tests the speed to hash num buffers of buflen each.
263func benchmarkHash(b *testing.B, h hash.Hash, size, num int) {
264	b.StopTimer()
265	h.Reset()
266	data := sequentialBytes(size)
267	b.SetBytes(int64(size * num))
268	b.StartTimer()
269
270	var state []byte
271	for i := 0; i < b.N; i++ {
272		for j := 0; j < num; j++ {
273			h.Write(data)
274		}
275		state = h.Sum(state[:0])
276	}
277	b.StopTimer()
278	h.Reset()
279}
280
281// benchmarkShake is specialized to the Shake instances, which don't
282// require a copy on reading output.
283func benchmarkShake(b *testing.B, h ShakeHash, size, num int) {
284	b.StopTimer()
285	h.Reset()
286	data := sequentialBytes(size)
287	d := make([]byte, 32)
288
289	b.SetBytes(int64(size * num))
290	b.StartTimer()
291
292	for i := 0; i < b.N; i++ {
293		h.Reset()
294		for j := 0; j < num; j++ {
295			h.Write(data)
296		}
297		h.Read(d)
298	}
299}
300
301func BenchmarkSha3_512_MTU(b *testing.B) { benchmarkHash(b, New512(), 1350, 1) }
302func BenchmarkSha3_384_MTU(b *testing.B) { benchmarkHash(b, New384(), 1350, 1) }
303func BenchmarkSha3_256_MTU(b *testing.B) { benchmarkHash(b, New256(), 1350, 1) }
304func BenchmarkSha3_224_MTU(b *testing.B) { benchmarkHash(b, New224(), 1350, 1) }
305
306func BenchmarkShake128_MTU(b *testing.B)  { benchmarkShake(b, NewShake128(), 1350, 1) }
307func BenchmarkShake256_MTU(b *testing.B)  { benchmarkShake(b, NewShake256(), 1350, 1) }
308func BenchmarkShake256_16x(b *testing.B)  { benchmarkShake(b, NewShake256(), 16, 1024) }
309func BenchmarkShake256_1MiB(b *testing.B) { benchmarkShake(b, NewShake256(), 1024, 1024) }
310
311func BenchmarkSha3_512_1MiB(b *testing.B) { benchmarkHash(b, New512(), 1024, 1024) }
312
313func Example_sum() {
314	buf := []byte("some data to hash")
315	// A hash needs to be 64 bytes long to have 256-bit collision resistance.
316	h := make([]byte, 64)
317	// Compute a 64-byte hash of buf and put it in h.
318	ShakeSum256(h, buf)
319	fmt.Printf("%x\n", h)
320	// Output: 0f65fe41fc353e52c55667bb9e2b27bfcc8476f2c413e9437d272ee3194a4e3146d05ec04a25d16b8f577c19b82d16b1424c3e022e783d2b4da98de3658d363d
321}
322
323func Example_mac() {
324	k := []byte("this is a secret key; you should generate a strong random key that's at least 32 bytes long")
325	buf := []byte("and this is some data to authenticate")
326	// A MAC with 32 bytes of output has 256-bit security strength -- if you use at least a 32-byte-long key.
327	h := make([]byte, 32)
328	d := NewShake256()
329	// Write the key into the hash.
330	d.Write(k)
331	// Now write the data.
332	d.Write(buf)
333	// Read 32 bytes of output from the hash into h.
334	d.Read(h)
335	fmt.Printf("%x\n", h)
336	// Output: 78de2974bd2711d5549ffd32b753ef0f5fa80a0db2556db60f0987eb8a9218ff
337}
338