1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package maphash
6
7import (
8	"fmt"
9	"math"
10	"math/rand"
11	"runtime"
12	"strings"
13	"testing"
14	"unsafe"
15)
16
// SMHasher is a torture test for hash functions.
// It was originally hosted at https://code.google.com/p/smhasher/
// and now lives at https://github.com/aappleby/smhasher.
// This code is a port of some of the SMHasher tests to Go.
20
// fixedSeed is obtained once from MakeSeed when the package is initialized.
// bytesHash and stringHash both seed their Hash with it, so every hash they
// compute during one test run is comparable with the others.
var fixedSeed = MakeSeed()
22
23// Sanity checks.
24// hash should not depend on values outside key.
25// hash should not depend on alignment.
26func TestSmhasherSanity(t *testing.T) {
27	r := rand.New(rand.NewSource(1234))
28	const REP = 10
29	const KEYMAX = 128
30	const PAD = 16
31	const OFFMAX = 16
32	for k := 0; k < REP; k++ {
33		for n := 0; n < KEYMAX; n++ {
34			for i := 0; i < OFFMAX; i++ {
35				var b [KEYMAX + OFFMAX + 2*PAD]byte
36				var c [KEYMAX + OFFMAX + 2*PAD]byte
37				randBytes(r, b[:])
38				randBytes(r, c[:])
39				copy(c[PAD+i:PAD+i+n], b[PAD:PAD+n])
40				if bytesHash(b[PAD:PAD+n]) != bytesHash(c[PAD+i:PAD+i+n]) {
41					t.Errorf("hash depends on bytes outside key")
42				}
43			}
44		}
45	}
46}
47
48func bytesHash(b []byte) uint64 {
49	var h Hash
50	h.SetSeed(fixedSeed)
51	h.Write(b)
52	return h.Sum64()
53}
54func stringHash(s string) uint64 {
55	var h Hash
56	h.SetSeed(fixedSeed)
57	h.WriteString(s)
58	return h.Sum64()
59}
60
// hashSize is the number of bits in the hash values under test. It sizes the
// per-output-bit counters in avalancheTest1 and is the exponent used for the
// expected collision rate in hashSet.check.
const hashSize = 64
62
// randBytes overwrites every byte of b with pseudo-random data drawn from r.
func randBytes(r *rand.Rand, b []byte) {
	// The result is deliberately discarded: (*rand.Rand).Read can't fail.
	_, _ = r.Read(b)
}
66
67// A hashSet measures the frequency of hash collisions.
68type hashSet struct {
69	m map[uint64]struct{} // set of hashes added
70	n int                 // number of hashes added
71}
72
73func newHashSet() *hashSet {
74	return &hashSet{make(map[uint64]struct{}), 0}
75}
76func (s *hashSet) add(h uint64) {
77	s.m[h] = struct{}{}
78	s.n++
79}
80func (s *hashSet) addS(x string) {
81	s.add(stringHash(x))
82}
83func (s *hashSet) addB(x []byte) {
84	s.add(bytesHash(x))
85}
86func (s *hashSet) addS_seed(x string, seed Seed) {
87	var h Hash
88	h.SetSeed(seed)
89	h.WriteString(x)
90	s.add(h.Sum64())
91}
92func (s *hashSet) check(t *testing.T) {
93	const SLOP = 10.0
94	collisions := s.n - len(s.m)
95	pairs := int64(s.n) * int64(s.n-1) / 2
96	expected := float64(pairs) / math.Pow(2.0, float64(hashSize))
97	stddev := math.Sqrt(expected)
98	if float64(collisions) > expected+SLOP*(3*stddev+1) {
99		t.Errorf("unexpected number of collisions: got=%d mean=%f stddev=%f", collisions, expected, stddev)
100	}
101}
102
103// a string plus adding zeros must make distinct hashes
104func TestSmhasherAppendedZeros(t *testing.T) {
105	s := "hello" + strings.Repeat("\x00", 256)
106	h := newHashSet()
107	for i := 0; i <= len(s); i++ {
108		h.addS(s[:i])
109	}
110	h.check(t)
111}
112
113// All 0-3 byte strings have distinct hashes.
114func TestSmhasherSmallKeys(t *testing.T) {
115	h := newHashSet()
116	var b [3]byte
117	for i := 0; i < 256; i++ {
118		b[0] = byte(i)
119		h.addB(b[:1])
120		for j := 0; j < 256; j++ {
121			b[1] = byte(j)
122			h.addB(b[:2])
123			if !testing.Short() {
124				for k := 0; k < 256; k++ {
125					b[2] = byte(k)
126					h.addB(b[:3])
127				}
128			}
129		}
130	}
131	h.check(t)
132}
133
134// Different length strings of all zeros have distinct hashes.
135func TestSmhasherZeros(t *testing.T) {
136	N := 256 * 1024
137	if testing.Short() {
138		N = 1024
139	}
140	h := newHashSet()
141	b := make([]byte, N)
142	for i := 0; i <= N; i++ {
143		h.addB(b[:i])
144	}
145	h.check(t)
146}
147
148// Strings with up to two nonzero bytes all have distinct hashes.
149func TestSmhasherTwoNonzero(t *testing.T) {
150	if runtime.GOARCH == "wasm" {
151		t.Skip("Too slow on wasm")
152	}
153	if testing.Short() {
154		t.Skip("Skipping in short mode")
155	}
156	h := newHashSet()
157	for n := 2; n <= 16; n++ {
158		twoNonZero(h, n)
159	}
160	h.check(t)
161}
162func twoNonZero(h *hashSet, n int) {
163	b := make([]byte, n)
164
165	// all zero
166	h.addB(b)
167
168	// one non-zero byte
169	for i := 0; i < n; i++ {
170		for x := 1; x < 256; x++ {
171			b[i] = byte(x)
172			h.addB(b)
173			b[i] = 0
174		}
175	}
176
177	// two non-zero bytes
178	for i := 0; i < n; i++ {
179		for x := 1; x < 256; x++ {
180			b[i] = byte(x)
181			for j := i + 1; j < n; j++ {
182				for y := 1; y < 256; y++ {
183					b[j] = byte(y)
184					h.addB(b)
185					b[j] = 0
186				}
187			}
188			b[i] = 0
189		}
190	}
191}
192
193// Test strings with repeats, like "abcdabcdabcdabcd..."
194func TestSmhasherCyclic(t *testing.T) {
195	if testing.Short() {
196		t.Skip("Skipping in short mode")
197	}
198	r := rand.New(rand.NewSource(1234))
199	const REPEAT = 8
200	const N = 1000000
201	for n := 4; n <= 12; n++ {
202		h := newHashSet()
203		b := make([]byte, REPEAT*n)
204		for i := 0; i < N; i++ {
205			b[0] = byte(i * 79 % 97)
206			b[1] = byte(i * 43 % 137)
207			b[2] = byte(i * 151 % 197)
208			b[3] = byte(i * 199 % 251)
209			randBytes(r, b[4:n])
210			for j := n; j < n*REPEAT; j++ {
211				b[j] = b[j-n]
212			}
213			h.addB(b)
214		}
215		h.check(t)
216	}
217}
218
219// Test strings with only a few bits set
220func TestSmhasherSparse(t *testing.T) {
221	if runtime.GOARCH == "wasm" {
222		t.Skip("Too slow on wasm")
223	}
224	if testing.Short() {
225		t.Skip("Skipping in short mode")
226	}
227	sparse(t, 32, 6)
228	sparse(t, 40, 6)
229	sparse(t, 48, 5)
230	sparse(t, 56, 5)
231	sparse(t, 64, 5)
232	sparse(t, 96, 4)
233	sparse(t, 256, 3)
234	sparse(t, 2048, 2)
235}
236func sparse(t *testing.T, n int, k int) {
237	b := make([]byte, n/8)
238	h := newHashSet()
239	setbits(h, b, 0, k)
240	h.check(t)
241}
242
243// set up to k bits at index i and greater
244func setbits(h *hashSet, b []byte, i int, k int) {
245	h.addB(b)
246	if k == 0 {
247		return
248	}
249	for j := i; j < len(b)*8; j++ {
250		b[j/8] |= byte(1 << uint(j&7))
251		setbits(h, b, j+1, k-1)
252		b[j/8] &= byte(^(1 << uint(j&7)))
253	}
254}
255
256// Test all possible combinations of n blocks from the set s.
257// "permutation" is a bad name here, but it is what Smhasher uses.
258func TestSmhasherPermutation(t *testing.T) {
259	if runtime.GOARCH == "wasm" {
260		t.Skip("Too slow on wasm")
261	}
262	if testing.Short() {
263		t.Skip("Skipping in short mode")
264	}
265	permutation(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, 8)
266	permutation(t, []uint32{0, 1 << 29, 2 << 29, 3 << 29, 4 << 29, 5 << 29, 6 << 29, 7 << 29}, 8)
267	permutation(t, []uint32{0, 1}, 20)
268	permutation(t, []uint32{0, 1 << 31}, 20)
269	permutation(t, []uint32{0, 1, 2, 3, 4, 5, 6, 7, 1 << 29, 2 << 29, 3 << 29, 4 << 29, 5 << 29, 6 << 29, 7 << 29}, 6)
270}
271func permutation(t *testing.T, s []uint32, n int) {
272	b := make([]byte, n*4)
273	h := newHashSet()
274	genPerm(h, b, s, 0)
275	h.check(t)
276}
277func genPerm(h *hashSet, b []byte, s []uint32, n int) {
278	h.addB(b[:n])
279	if n == len(b) {
280		return
281	}
282	for _, v := range s {
283		b[n] = byte(v)
284		b[n+1] = byte(v >> 8)
285		b[n+2] = byte(v >> 16)
286		b[n+3] = byte(v >> 24)
287		genPerm(h, b, s, n+4)
288	}
289}
290
// key is a mutable hash input that the bit-level tests (avalanche, windowed)
// manipulate without knowing its concrete representation.
type key interface {
	name() string        // label for the key, used in error reports
	bits() int           // number of bits the key has
	clear()              // set all bits of the key to 0
	random(r *rand.Rand) // set the key to something random drawn from r
	flipBit(i int)       // flip bit i of the key
	hash() uint64        // hash the key
}
299
300type bytesKey struct {
301	b []byte
302}
303
304func (k *bytesKey) clear() {
305	for i := range k.b {
306		k.b[i] = 0
307	}
308}
309func (k *bytesKey) random(r *rand.Rand) {
310	randBytes(r, k.b)
311}
312func (k *bytesKey) bits() int {
313	return len(k.b) * 8
314}
315func (k *bytesKey) flipBit(i int) {
316	k.b[i>>3] ^= byte(1 << uint(i&7))
317}
318func (k *bytesKey) hash() uint64 {
319	return bytesHash(k.b)
320}
321func (k *bytesKey) name() string {
322	return fmt.Sprintf("bytes%d", len(k.b))
323}
324
325// Flipping a single bit of a key should flip each output bit with 50% probability.
326func TestSmhasherAvalanche(t *testing.T) {
327	if runtime.GOARCH == "wasm" {
328		t.Skip("Too slow on wasm")
329	}
330	if testing.Short() {
331		t.Skip("Skipping in short mode")
332	}
333	avalancheTest1(t, &bytesKey{make([]byte, 2)})
334	avalancheTest1(t, &bytesKey{make([]byte, 4)})
335	avalancheTest1(t, &bytesKey{make([]byte, 8)})
336	avalancheTest1(t, &bytesKey{make([]byte, 16)})
337	avalancheTest1(t, &bytesKey{make([]byte, 32)})
338	avalancheTest1(t, &bytesKey{make([]byte, 200)})
339}
340func avalancheTest1(t *testing.T, k key) {
341	const REP = 100000
342	r := rand.New(rand.NewSource(1234))
343	n := k.bits()
344
345	// grid[i][j] is a count of whether flipping
346	// input bit i affects output bit j.
347	grid := make([][hashSize]int, n)
348
349	for z := 0; z < REP; z++ {
350		// pick a random key, hash it
351		k.random(r)
352		h := k.hash()
353
354		// flip each bit, hash & compare the results
355		for i := 0; i < n; i++ {
356			k.flipBit(i)
357			d := h ^ k.hash()
358			k.flipBit(i)
359
360			// record the effects of that bit flip
361			g := &grid[i]
362			for j := 0; j < hashSize; j++ {
363				g[j] += int(d & 1)
364				d >>= 1
365			}
366		}
367	}
368
369	// Each entry in the grid should be about REP/2.
370	// More precisely, we did N = k.bits() * hashSize experiments where
371	// each is the sum of REP coin flips. We want to find bounds on the
372	// sum of coin flips such that a truly random experiment would have
373	// all sums inside those bounds with 99% probability.
374	N := n * hashSize
375	var c float64
376	// find c such that Prob(mean-c*stddev < x < mean+c*stddev)^N > .9999
377	for c = 0.0; math.Pow(math.Erf(c/math.Sqrt(2)), float64(N)) < .9999; c += .1 {
378	}
379	c *= 4.0 // allowed slack - we don't need to be perfectly random
380	mean := .5 * REP
381	stddev := .5 * math.Sqrt(REP)
382	low := int(mean - c*stddev)
383	high := int(mean + c*stddev)
384	for i := 0; i < n; i++ {
385		for j := 0; j < hashSize; j++ {
386			x := grid[i][j]
387			if x < low || x > high {
388				t.Errorf("bad bias for %s bit %d -> bit %d: %d/%d\n", k.name(), i, j, x, REP)
389			}
390		}
391	}
392}
393
394// All bit rotations of a set of distinct keys
395func TestSmhasherWindowed(t *testing.T) {
396	windowed(t, &bytesKey{make([]byte, 128)})
397}
398func windowed(t *testing.T, k key) {
399	if runtime.GOARCH == "wasm" {
400		t.Skip("Too slow on wasm")
401	}
402	if testing.Short() {
403		t.Skip("Skipping in short mode")
404	}
405	const BITS = 16
406
407	for r := 0; r < k.bits(); r++ {
408		h := newHashSet()
409		for i := 0; i < 1<<BITS; i++ {
410			k.clear()
411			for j := 0; j < BITS; j++ {
412				if i>>uint(j)&1 != 0 {
413					k.flipBit((j + r) % k.bits())
414				}
415			}
416			h.add(k.hash())
417		}
418		h.check(t)
419	}
420}
421
422// All keys of the form prefix + [A-Za-z0-9]*N + suffix.
423func TestSmhasherText(t *testing.T) {
424	if testing.Short() {
425		t.Skip("Skipping in short mode")
426	}
427	text(t, "Foo", "Bar")
428	text(t, "FooBar", "")
429	text(t, "", "FooBar")
430}
431func text(t *testing.T, prefix, suffix string) {
432	const N = 4
433	const S = "ABCDEFGHIJKLMNOPQRSTabcdefghijklmnopqrst0123456789"
434	const L = len(S)
435	b := make([]byte, len(prefix)+N+len(suffix))
436	copy(b, prefix)
437	copy(b[len(prefix)+N:], suffix)
438	h := newHashSet()
439	c := b[len(prefix):]
440	for i := 0; i < L; i++ {
441		c[0] = S[i]
442		for j := 0; j < L; j++ {
443			c[1] = S[j]
444			for k := 0; k < L; k++ {
445				c[2] = S[k]
446				for x := 0; x < L; x++ {
447					c[3] = S[x]
448					h.addB(b)
449				}
450			}
451		}
452	}
453	h.check(t)
454}
455
456// Make sure different seed values generate different hashes.
457func TestSmhasherSeed(t *testing.T) {
458	if unsafe.Sizeof(uintptr(0)) == 4 {
459		t.Skip("32-bit platforms don't have ideal seed-input distributions (see issue 33988)")
460	}
461	h := newHashSet()
462	const N = 100000
463	s := "hello"
464	for i := 0; i < N; i++ {
465		h.addS_seed(s, Seed{s: uint64(i + 1)})
466		h.addS_seed(s, Seed{s: uint64(i+1) << 32}) // make sure high bits are used
467	}
468	h.check(t)
469}
470