1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"bufio"
9	"bytes"
10	"encoding/binary"
11	"encoding/hex"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"log"
16	"math/rand"
17	"os"
18	"path/filepath"
19	"reflect"
20	"runtime"
21	"strings"
22	"sync"
23	"testing"
24	"time"
25
26	// "github.com/DataDog/zstd"
27	// zstd "github.com/valyala/gozstd"
28
29	"github.com/klauspost/compress/zip"
30	"github.com/klauspost/compress/zstd/internal/xxhash"
31)
32
33func TestNewReaderMismatch(t *testing.T) {
34	// To identify a potential decoding error, do the following steps:
35	// 1) Place the compressed file in testdata, eg 'testdata/backup.bin.zst'
36	// 2) Decompress the file to using zstd, so it will be named 'testdata/backup.bin'
37	// 3) Run the test. A hash file will be generated 'testdata/backup.bin.hash'
38	// 4) The decoder will also run and decode the file. It will stop as soon as a mismatch is found.
39	// The hash file will be reused between runs if present.
40	const baseFile = "testdata/backup.bin"
41	const blockSize = 1024
42	hashes, err := ioutil.ReadFile(baseFile + ".hash")
43	if os.IsNotExist(err) {
44		// Create the hash file.
45		f, err := os.Open(baseFile)
46		if os.IsNotExist(err) {
47			t.Skip("no decompressed file found")
48			return
49		}
50		defer f.Close()
51		br := bufio.NewReader(f)
52		var tmp [8]byte
53		xx := xxhash.New()
54		for {
55			xx.Reset()
56			buf := make([]byte, blockSize)
57			n, err := io.ReadFull(br, buf)
58			if err != nil {
59				if err != io.EOF && err != io.ErrUnexpectedEOF {
60					t.Fatal(err)
61				}
62			}
63			if n > 0 {
64				_, _ = xx.Write(buf[:n])
65				binary.LittleEndian.PutUint64(tmp[:], xx.Sum64())
66				hashes = append(hashes, tmp[4:]...)
67			}
68			if n != blockSize {
69				break
70			}
71		}
72		err = ioutil.WriteFile(baseFile+".hash", hashes, os.ModePerm)
73		if err != nil {
74			// We can continue for now
75			t.Error(err)
76		}
77		t.Log("Saved", len(hashes)/4, "hashes as", baseFile+".hash")
78	}
79
80	f, err := os.Open(baseFile + ".zst")
81	if os.IsNotExist(err) {
82		t.Skip("no compressed file found")
83		return
84	}
85	defer f.Close()
86	dec, err := NewReader(f, WithDecoderConcurrency(1))
87	if err != nil {
88		t.Fatal(err)
89	}
90	defer dec.Close()
91	var tmp [8]byte
92	xx := xxhash.New()
93	var cHash int
94	for {
95		xx.Reset()
96		buf := make([]byte, blockSize)
97		n, err := io.ReadFull(dec, buf)
98		if err != nil {
99			if err != io.EOF && err != io.ErrUnexpectedEOF {
100				t.Fatal("block", cHash, "err:", err)
101			}
102		}
103		if n > 0 {
104			if cHash+4 > len(hashes) {
105				extra, _ := io.Copy(ioutil.Discard, dec)
106				t.Fatal("not enough hashes (length mismatch). Only have", len(hashes)/4, "hashes. Got block of", n, "bytes and", extra, "bytes still on stream.")
107			}
108			_, _ = xx.Write(buf[:n])
109			binary.LittleEndian.PutUint64(tmp[:], xx.Sum64())
110			want, got := hashes[cHash:cHash+4], tmp[4:]
111			if !bytes.Equal(want, got) {
112				org, err := os.Open(baseFile)
113				if err == nil {
114					const sizeBack = 8 << 20
115					defer org.Close()
116					start := int64(cHash)/4*blockSize - sizeBack
117					if start < 0 {
118						start = 0
119					}
120					_, err = org.Seek(start, io.SeekStart)
121					buf2 := make([]byte, sizeBack+1<<20)
122					n, _ := io.ReadFull(org, buf2)
123					if n > 0 {
124						err = ioutil.WriteFile(baseFile+".section", buf2[:n], os.ModePerm)
125						if err == nil {
126							t.Log("Wrote problematic section to", baseFile+".section")
127						}
128					}
129				}
130
131				t.Fatal("block", cHash/4, "offset", cHash/4*blockSize, "hash mismatch, want:", hex.EncodeToString(want), "got:", hex.EncodeToString(got))
132			}
133			cHash += 4
134		}
135		if n != blockSize {
136			break
137		}
138	}
139	t.Log("Output matched")
140}
141
142func TestNewDecoder(t *testing.T) {
143	defer timeout(60 * time.Second)()
144	testDecoderFile(t, "testdata/decoder.zip")
145	dec, err := NewReader(nil)
146	if err != nil {
147		t.Fatal(err)
148	}
149	testDecoderDecodeAll(t, "testdata/decoder.zip", dec)
150}
151
152func TestNewDecoderGood(t *testing.T) {
153	defer timeout(30 * time.Second)()
154	testDecoderFile(t, "testdata/good.zip")
155	dec, err := NewReader(nil)
156	if err != nil {
157		t.Fatal(err)
158	}
159	testDecoderDecodeAll(t, "testdata/good.zip", dec)
160}
161
162func TestNewDecoderBad(t *testing.T) {
163	defer timeout(10 * time.Second)()
164	dec, err := NewReader(nil)
165	if err != nil {
166		t.Fatal(err)
167	}
168	testDecoderDecodeAllError(t, "testdata/bad.zip", dec)
169}
170
171func TestNewDecoderLarge(t *testing.T) {
172	testDecoderFile(t, "testdata/large.zip")
173	dec, err := NewReader(nil)
174	if err != nil {
175		t.Fatal(err)
176	}
177	testDecoderDecodeAll(t, "testdata/large.zip", dec)
178}
179
180func TestNewReaderRead(t *testing.T) {
181	dec, err := NewReader(nil)
182	if err != nil {
183		t.Fatal(err)
184	}
185	defer dec.Close()
186	_, err = dec.Read([]byte{0})
187	if err == nil {
188		t.Fatal("Wanted error on uninitialized read, got nil")
189	}
190	t.Log("correctly got error", err)
191}
192
193func TestNewDecoderBig(t *testing.T) {
194	if testing.Short() {
195		t.SkipNow()
196	}
197	file := "testdata/zstd-10kfiles.zip"
198	if _, err := os.Stat(file); os.IsNotExist(err) {
199		t.Skip("To run extended tests, download https://files.klauspost.com/compress/zstd-10kfiles.zip \n" +
200			"and place it in " + file + "\n" + "Running it requires about 5GB of RAM")
201	}
202	testDecoderFile(t, file)
203	dec, err := NewReader(nil)
204	if err != nil {
205		t.Fatal(err)
206	}
207	testDecoderDecodeAll(t, file, dec)
208}
209
210func TestNewDecoderBigFile(t *testing.T) {
211	if testing.Short() {
212		t.SkipNow()
213	}
214	file := "testdata/enwik9.zst"
215	const wantSize = 1000000000
216	if _, err := os.Stat(file); os.IsNotExist(err) {
217		t.Skip("To run extended tests, download http://mattmahoney.net/dc/enwik9.zip unzip it \n" +
218			"compress it with 'zstd -15 -T0 enwik9' and place it in " + file)
219	}
220	f, err := os.Open(file)
221	if err != nil {
222		t.Fatal(err)
223	}
224	defer f.Close()
225	start := time.Now()
226	dec, err := NewReader(f)
227	if err != nil {
228		t.Fatal(err)
229	}
230	n, err := io.Copy(ioutil.Discard, dec)
231	if err != nil {
232		t.Fatal(err)
233	}
234	if n != wantSize {
235		t.Errorf("want size %d, got size %d", wantSize, n)
236	}
237	elapsed := time.Since(start)
238	mbpersec := (float64(n) / (1024 * 1024)) / (float64(elapsed) / (float64(time.Second)))
239	t.Logf("Decoded %d bytes with %f.2 MB/s", n, mbpersec)
240}
241
242func TestNewDecoderSmallFile(t *testing.T) {
243	if testing.Short() {
244		t.SkipNow()
245	}
246	file := "testdata/z000028.zst"
247	const wantSize = 39807
248	f, err := os.Open(file)
249	if err != nil {
250		t.Fatal(err)
251	}
252	defer f.Close()
253	start := time.Now()
254	dec, err := NewReader(f)
255	if err != nil {
256		t.Fatal(err)
257	}
258	defer dec.Close()
259	n, err := io.Copy(ioutil.Discard, dec)
260	if err != nil {
261		t.Fatal(err)
262	}
263	if n != wantSize {
264		t.Errorf("want size %d, got size %d", wantSize, n)
265	}
266	mbpersec := (float64(n) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
267	t.Logf("Decoded %d bytes with %f.2 MB/s", n, mbpersec)
268}
269
// readAndBlock is an io.Reader that serves buf and then, once buf is
// exhausted, blocks until unblock is closed (or receives) before returning
// io.EOF. It simulates a stream that stays open after the available data has
// been consumed.
type readAndBlock struct {
	buf     []byte        // remaining data to serve
	unblock chan struct{} // released to let the final read return io.EOF
}

// Read copies as much of the remaining buffer into p as fits. It blocks only
// when no data remains. A zero-length p returns (0, nil) without blocking
// while data is still available, as the io.Reader contract allows.
func (r *readAndBlock) Read(p []byte) (int, error) {
	if len(r.buf) == 0 {
		<-r.unblock
		return 0, io.EOF
	}
	n := copy(p, r.buf)
	r.buf = r.buf[n:]
	return n, nil
}
284
285func TestNewDecoderFlushed(t *testing.T) {
286	if testing.Short() {
287		t.SkipNow()
288	}
289	file := "testdata/z000028.zst"
290	payload, err := ioutil.ReadFile(file)
291	if err != nil {
292		t.Fatal(err)
293	}
294	payload = append(payload, payload...) //2x
295	payload = append(payload, payload...) //4x
296	payload = append(payload, payload...) //8x
297	rng := rand.New(rand.NewSource(0x1337))
298	runs := 100
299	if testing.Short() {
300		runs = 5
301	}
302	enc, err := NewWriter(nil, WithWindowSize(128<<10))
303	if err != nil {
304		t.Fatal(err)
305	}
306	defer enc.Close()
307	for i := 0; i < runs; i++ {
308		wantSize := rng.Intn(len(payload)-1) + 1
309		t.Run(fmt.Sprint("size-", wantSize), func(t *testing.T) {
310			var encoded bytes.Buffer
311			enc.Reset(&encoded)
312			_, err := enc.Write(payload[:wantSize])
313			if err != nil {
314				t.Fatal(err)
315			}
316			err = enc.Flush()
317			if err != nil {
318				t.Fatal(err)
319			}
320
321			// We must be able to read back up until the flush...
322			r := readAndBlock{
323				buf:     encoded.Bytes(),
324				unblock: make(chan struct{}),
325			}
326			defer timeout(5 * time.Second)()
327			dec, err := NewReader(&r)
328			if err != nil {
329				t.Fatal(err)
330			}
331			defer dec.Close()
332			defer close(r.unblock)
333			readBack := 0
334			dst := make([]byte, 1024)
335			for readBack < wantSize {
336				// Read until we have enough.
337				n, err := dec.Read(dst)
338				if err != nil {
339					t.Fatal(err)
340				}
341				readBack += n
342			}
343		})
344	}
345}
346
347func TestDecoderRegression(t *testing.T) {
348	defer timeout(160 * time.Second)()
349	data, err := ioutil.ReadFile("testdata/regression.zip")
350	if err != nil {
351		t.Fatal(err)
352	}
353	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
354	if err != nil {
355		t.Fatal(err)
356	}
357	dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderLowmem(true), WithDecoderMaxMemory(1<<20))
358	if err != nil {
359		t.Error(err)
360		return
361	}
362	defer dec.Close()
363	for i, tt := range zr.File {
364		if !strings.HasSuffix(tt.Name, "") || (testing.Short() && i > 10) {
365			continue
366		}
367		t.Run("Reader-"+tt.Name, func(t *testing.T) {
368			r, err := tt.Open()
369			if err != nil {
370				t.Error(err)
371				return
372			}
373			err = dec.Reset(r)
374			if err != nil {
375				t.Error(err)
376				return
377			}
378			got, gotErr := ioutil.ReadAll(dec)
379			t.Log("Received:", len(got), gotErr)
380
381			// Check a fresh instance
382			r, err = tt.Open()
383			if err != nil {
384				t.Error(err)
385				return
386			}
387			decL, err := NewReader(r, WithDecoderConcurrency(1), WithDecoderLowmem(true), WithDecoderMaxMemory(1<<20))
388			if err != nil {
389				t.Error(err)
390				return
391			}
392			defer decL.Close()
393			got2, gotErr2 := ioutil.ReadAll(decL)
394			t.Log("Fresh Reader received:", len(got2), gotErr2)
395			if gotErr != gotErr2 {
396				if gotErr != nil && gotErr2 != nil && gotErr.Error() != gotErr2.Error() {
397					t.Error(gotErr, "!=", gotErr2)
398				}
399				if (gotErr == nil) != (gotErr2 == nil) {
400					t.Error(gotErr, "!=", gotErr2)
401				}
402			}
403			if !bytes.Equal(got2, got) {
404				if gotErr != nil {
405					t.Log("Buffer mismatch without Reset")
406				} else {
407					t.Error("Buffer mismatch without Reset")
408				}
409			}
410		})
411		t.Run("DecodeAll-"+tt.Name, func(t *testing.T) {
412			r, err := tt.Open()
413			if err != nil {
414				t.Error(err)
415				return
416			}
417			in, err := ioutil.ReadAll(r)
418			if err != nil {
419				t.Error(err)
420			}
421			got, gotErr := dec.DecodeAll(in, nil)
422			t.Log("Received:", len(got), gotErr)
423
424			// Check if we got the same:
425			decL, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderLowmem(true), WithDecoderMaxMemory(1<<20))
426			if err != nil {
427				t.Error(err)
428				return
429			}
430			defer decL.Close()
431			got2, gotErr2 := decL.DecodeAll(in, nil)
432			t.Log("Fresh Reader received:", len(got2), gotErr2)
433			if gotErr != gotErr2 {
434				if gotErr != nil && gotErr2 != nil && gotErr.Error() != gotErr2.Error() {
435					t.Error(gotErr, "!=", gotErr2)
436				}
437				if (gotErr == nil) != (gotErr2 == nil) {
438					t.Error(gotErr, "!=", gotErr2)
439				}
440			}
441			if !bytes.Equal(got2, got) {
442				if gotErr != nil {
443					t.Log("Buffer mismatch without Reset")
444				} else {
445					t.Error("Buffer mismatch without Reset")
446				}
447			}
448		})
449		t.Run("Match-"+tt.Name, func(t *testing.T) {
450			r, err := tt.Open()
451			if err != nil {
452				t.Error(err)
453				return
454			}
455			in, err := ioutil.ReadAll(r)
456			if err != nil {
457				t.Error(err)
458			}
459			got, gotErr := dec.DecodeAll(in, nil)
460			t.Log("Received:", len(got), gotErr)
461
462			// Check a fresh instance
463			decL, err := NewReader(bytes.NewBuffer(in), WithDecoderConcurrency(1), WithDecoderLowmem(true), WithDecoderMaxMemory(1<<20))
464			if err != nil {
465				t.Error(err)
466				return
467			}
468			defer decL.Close()
469			got2, gotErr2 := ioutil.ReadAll(decL)
470			t.Log("Reader Reader received:", len(got2), gotErr2)
471			if gotErr != gotErr2 {
472				if gotErr != nil && gotErr2 != nil && gotErr.Error() != gotErr2.Error() {
473					t.Error(gotErr, "!=", gotErr2)
474				}
475				if (gotErr == nil) != (gotErr2 == nil) {
476					t.Error(gotErr, "!=", gotErr2)
477				}
478			}
479			if !bytes.Equal(got2, got) {
480				if gotErr != nil {
481					t.Log("Buffer mismatch")
482				} else {
483					t.Error("Buffer mismatch")
484				}
485			}
486		})
487	}
488}
489
490func TestDecoder_Reset(t *testing.T) {
491	in, err := ioutil.ReadFile("testdata/z000028")
492	if err != nil {
493		t.Fatal(err)
494	}
495	in = append(in, in...)
496	var e Encoder
497	start := time.Now()
498	dst := e.EncodeAll(in, nil)
499	t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
500	mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
501	t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
502
503	dec, err := NewReader(nil)
504	if err != nil {
505		t.Fatal(err)
506	}
507	defer dec.Close()
508	decoded, err := dec.DecodeAll(dst, nil)
509	if err != nil {
510		t.Error(err, len(decoded))
511	}
512	if !bytes.Equal(decoded, in) {
513		t.Fatal("Decoded does not match")
514	}
515	t.Log("Encoded content matched")
516
517	// Decode using reset+copy
518	for i := 0; i < 3; i++ {
519		err = dec.Reset(bytes.NewBuffer(dst))
520		if err != nil {
521			t.Fatal(err)
522		}
523		var dBuf bytes.Buffer
524		n, err := io.Copy(&dBuf, dec)
525		if err != nil {
526			t.Fatal(err)
527		}
528		decoded = dBuf.Bytes()
529		if int(n) != len(decoded) {
530			t.Fatalf("decoded reported length mismatch %d != %d", n, len(decoded))
531		}
532		if !bytes.Equal(decoded, in) {
533			ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
534			ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
535			t.Fatal("Decoded does not match")
536		}
537	}
538	// Test without WriterTo interface support.
539	for i := 0; i < 3; i++ {
540		err = dec.Reset(bytes.NewBuffer(dst))
541		if err != nil {
542			t.Fatal(err)
543		}
544		decoded, err := ioutil.ReadAll(ioutil.NopCloser(dec))
545		if err != nil {
546			t.Fatal(err)
547		}
548		if !bytes.Equal(decoded, in) {
549			ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
550			ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
551			t.Fatal("Decoded does not match")
552		}
553	}
554}
555
556func TestDecoderMultiFrame(t *testing.T) {
557	fn := "testdata/benchdecoder.zip"
558	data, err := ioutil.ReadFile(fn)
559	if err != nil {
560		t.Fatal(err)
561	}
562	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
563	if err != nil {
564		t.Fatal(err)
565	}
566	dec, err := NewReader(nil)
567	if err != nil {
568		t.Fatal(err)
569		return
570	}
571	defer dec.Close()
572	for _, tt := range zr.File {
573		if !strings.HasSuffix(tt.Name, ".zst") {
574			continue
575		}
576		t.Run(tt.Name, func(t *testing.T) {
577			r, err := tt.Open()
578			if err != nil {
579				t.Fatal(err)
580			}
581			defer r.Close()
582			in, err := ioutil.ReadAll(r)
583			if err != nil {
584				t.Fatal(err)
585			}
586			// 2x
587			in = append(in, in...)
588			if !testing.Short() {
589				// 4x
590				in = append(in, in...)
591				// 8x
592				in = append(in, in...)
593			}
594			err = dec.Reset(bytes.NewBuffer(in))
595			if err != nil {
596				t.Fatal(err)
597			}
598			got, err := ioutil.ReadAll(dec)
599			if err != nil {
600				t.Fatal(err)
601			}
602			err = dec.Reset(bytes.NewBuffer(in))
603			if err != nil {
604				t.Fatal(err)
605			}
606			got2, err := ioutil.ReadAll(dec)
607			if err != nil {
608				t.Fatal(err)
609			}
610			if !bytes.Equal(got, got2) {
611				t.Error("results mismatch")
612			}
613		})
614	}
615}
616
617func TestDecoderMultiFrameReset(t *testing.T) {
618	fn := "testdata/benchdecoder.zip"
619	data, err := ioutil.ReadFile(fn)
620	if err != nil {
621		t.Fatal(err)
622	}
623	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
624	if err != nil {
625		t.Fatal(err)
626	}
627	dec, err := NewReader(nil)
628	if err != nil {
629		t.Fatal(err)
630		return
631	}
632	rng := rand.New(rand.NewSource(1337))
633	defer dec.Close()
634	for _, tt := range zr.File {
635		if !strings.HasSuffix(tt.Name, ".zst") {
636			continue
637		}
638		t.Run(tt.Name, func(t *testing.T) {
639			r, err := tt.Open()
640			if err != nil {
641				t.Fatal(err)
642			}
643			defer r.Close()
644			in, err := ioutil.ReadAll(r)
645			if err != nil {
646				t.Fatal(err)
647			}
648			// 2x
649			in = append(in, in...)
650			if !testing.Short() {
651				// 4x
652				in = append(in, in...)
653				// 8x
654				in = append(in, in...)
655			}
656			err = dec.Reset(bytes.NewBuffer(in))
657			if err != nil {
658				t.Fatal(err)
659			}
660			got, err := ioutil.ReadAll(dec)
661			if err != nil {
662				t.Fatal(err)
663			}
664			err = dec.Reset(bytes.NewBuffer(in))
665			if err != nil {
666				t.Fatal(err)
667			}
668			// Read a random number of bytes
669			tmp := make([]byte, rng.Intn(len(got)))
670			_, err = io.ReadAtLeast(dec, tmp, len(tmp))
671			if err != nil {
672				t.Fatal(err)
673			}
674			err = dec.Reset(bytes.NewBuffer(in))
675			if err != nil {
676				t.Fatal(err)
677			}
678			got2, err := ioutil.ReadAll(dec)
679			if err != nil {
680				t.Fatal(err)
681			}
682			if !bytes.Equal(got, got2) {
683				t.Error("results mismatch")
684			}
685		})
686	}
687}
688
689func testDecoderFile(t *testing.T, fn string) {
690	data, err := ioutil.ReadFile(fn)
691	if err != nil {
692		t.Fatal(err)
693	}
694	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
695	if err != nil {
696		t.Fatal(err)
697	}
698	var want = make(map[string][]byte)
699	for _, tt := range zr.File {
700		if strings.HasSuffix(tt.Name, ".zst") {
701			continue
702		}
703		r, err := tt.Open()
704		if err != nil {
705			t.Fatal(err)
706			return
707		}
708		want[tt.Name+".zst"], _ = ioutil.ReadAll(r)
709	}
710
711	dec, err := NewReader(nil)
712	if err != nil {
713		t.Error(err)
714		return
715	}
716	defer dec.Close()
717	for i, tt := range zr.File {
718		if !strings.HasSuffix(tt.Name, ".zst") || (testing.Short() && i > 20) {
719			continue
720		}
721		t.Run("Reader-"+tt.Name, func(t *testing.T) {
722			r, err := tt.Open()
723			if err != nil {
724				t.Error(err)
725				return
726			}
727			defer r.Close()
728			err = dec.Reset(r)
729			if err != nil {
730				t.Error(err)
731				return
732			}
733			got, err := ioutil.ReadAll(dec)
734			if err != nil {
735				t.Error(err)
736				if err != ErrCRCMismatch {
737					return
738				}
739			}
740			wantB := want[tt.Name]
741			if !bytes.Equal(wantB, got) {
742				if len(wantB)+len(got) < 1000 {
743					t.Logf(" got: %v\nwant: %v", got, wantB)
744				} else {
745					fileName, _ := filepath.Abs(filepath.Join("testdata", t.Name()+"-want.bin"))
746					_ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm)
747					err := ioutil.WriteFile(fileName, wantB, os.ModePerm)
748					t.Log("Wrote file", fileName, err)
749
750					fileName, _ = filepath.Abs(filepath.Join("testdata", t.Name()+"-got.bin"))
751					_ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm)
752					err = ioutil.WriteFile(fileName, got, os.ModePerm)
753					t.Log("Wrote file", fileName, err)
754				}
755				t.Logf("Length, want: %d, got: %d", len(wantB), len(got))
756				t.Error("Output mismatch")
757				return
758			}
759			t.Log(len(got), "bytes returned, matches input, ok!")
760		})
761	}
762}
763
764func BenchmarkDecoder_DecoderSmall(b *testing.B) {
765	fn := "testdata/benchdecoder.zip"
766	data, err := ioutil.ReadFile(fn)
767	if err != nil {
768		b.Fatal(err)
769	}
770	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
771	if err != nil {
772		b.Fatal(err)
773	}
774	dec, err := NewReader(nil)
775	if err != nil {
776		b.Fatal(err)
777		return
778	}
779	defer dec.Close()
780	for _, tt := range zr.File {
781		if !strings.HasSuffix(tt.Name, ".zst") {
782			continue
783		}
784		b.Run(tt.Name, func(b *testing.B) {
785			r, err := tt.Open()
786			if err != nil {
787				b.Fatal(err)
788			}
789			defer r.Close()
790			in, err := ioutil.ReadAll(r)
791			if err != nil {
792				b.Fatal(err)
793			}
794			// 2x
795			in = append(in, in...)
796			// 4x
797			in = append(in, in...)
798			// 8x
799			in = append(in, in...)
800			err = dec.Reset(bytes.NewBuffer(in))
801			if err != nil {
802				b.Fatal(err)
803			}
804			got, err := ioutil.ReadAll(dec)
805			if err != nil {
806				b.Fatal(err)
807			}
808			b.SetBytes(int64(len(got)))
809			b.ReportAllocs()
810			b.ResetTimer()
811			for i := 0; i < b.N; i++ {
812				err = dec.Reset(bytes.NewBuffer(in))
813				if err != nil {
814					b.Fatal(err)
815				}
816				_, err := io.Copy(ioutil.Discard, dec)
817				if err != nil {
818					b.Fatal(err)
819				}
820			}
821		})
822	}
823}
824
825func BenchmarkDecoder_DecodeAll(b *testing.B) {
826	fn := "testdata/benchdecoder.zip"
827	data, err := ioutil.ReadFile(fn)
828	if err != nil {
829		b.Fatal(err)
830	}
831	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
832	if err != nil {
833		b.Fatal(err)
834	}
835	dec, err := NewReader(nil, WithDecoderConcurrency(1))
836	if err != nil {
837		b.Fatal(err)
838		return
839	}
840	defer dec.Close()
841	for _, tt := range zr.File {
842		if !strings.HasSuffix(tt.Name, ".zst") {
843			continue
844		}
845		b.Run(tt.Name, func(b *testing.B) {
846			r, err := tt.Open()
847			if err != nil {
848				b.Fatal(err)
849			}
850			defer r.Close()
851			in, err := ioutil.ReadAll(r)
852			if err != nil {
853				b.Fatal(err)
854			}
855			got, err := dec.DecodeAll(in, nil)
856			if err != nil {
857				b.Fatal(err)
858			}
859			b.SetBytes(int64(len(got)))
860			b.ReportAllocs()
861			b.ResetTimer()
862			for i := 0; i < b.N; i++ {
863				_, err = dec.DecodeAll(in, got[:0])
864				if err != nil {
865					b.Fatal(err)
866				}
867			}
868		})
869	}
870}
871
872/*
873func BenchmarkDecoder_DecodeAllCgo(b *testing.B) {
874	fn := "testdata/benchdecoder.zip"
875	data, err := ioutil.ReadFile(fn)
876	if err != nil {
877		b.Fatal(err)
878	}
879	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
880	if err != nil {
881		b.Fatal(err)
882	}
883	for _, tt := range zr.File {
884		if !strings.HasSuffix(tt.Name, ".zst") {
885			continue
886		}
887		b.Run(tt.Name, func(b *testing.B) {
888			tt := tt
889			r, err := tt.Open()
890			if err != nil {
891				b.Fatal(err)
892			}
893			defer r.Close()
894			in, err := ioutil.ReadAll(r)
895			if err != nil {
896				b.Fatal(err)
897			}
898			got, err := zstd.Decompress(nil, in)
899			if err != nil {
900				b.Fatal(err)
901			}
902			b.SetBytes(int64(len(got)))
903			b.ReportAllocs()
904			b.ResetTimer()
905			for i := 0; i < b.N; i++ {
906				got, err = zstd.Decompress(got[:0], in)
907				if err != nil {
908					b.Fatal(err)
909				}
910			}
911		})
912	}
913}
914
915func BenchmarkDecoderSilesiaCgo(b *testing.B) {
916	fn := "testdata/silesia.tar.zst"
917	data, err := ioutil.ReadFile(fn)
918	if err != nil {
919		if os.IsNotExist(err) {
920			b.Skip("Missing testdata/silesia.tar.zst")
921			return
922		}
923		b.Fatal(err)
924	}
925	dec := zstd.NewReader(bytes.NewBuffer(data))
926	n, err := io.Copy(ioutil.Discard, dec)
927	if err != nil {
928		b.Fatal(err)
929	}
930
931	b.SetBytes(n)
932	b.ReportAllocs()
933	b.ResetTimer()
934	for i := 0; i < b.N; i++ {
935		dec := zstd.NewReader(bytes.NewBuffer(data))
936		_, err := io.CopyN(ioutil.Discard, dec, n)
937		if err != nil {
938			b.Fatal(err)
939		}
940	}
941}
942func BenchmarkDecoderEnwik9Cgo(b *testing.B) {
943	fn := "testdata/enwik9-1.zst"
944	data, err := ioutil.ReadFile(fn)
945	if err != nil {
946		if os.IsNotExist(err) {
947			b.Skip("Missing " + fn)
948			return
949		}
950		b.Fatal(err)
951	}
952	dec := zstd.NewReader(bytes.NewBuffer(data))
953	n, err := io.Copy(ioutil.Discard, dec)
954	if err != nil {
955		b.Fatal(err)
956	}
957
958	b.SetBytes(n)
959	b.ReportAllocs()
960	b.ResetTimer()
961	for i := 0; i < b.N; i++ {
962		dec := zstd.NewReader(bytes.NewBuffer(data))
963		_, err := io.CopyN(ioutil.Discard, dec, n)
964		if err != nil {
965			b.Fatal(err)
966		}
967	}
968}
969
970*/
971
972func BenchmarkDecoderSilesia(b *testing.B) {
973	fn := "testdata/silesia.tar.zst"
974	data, err := ioutil.ReadFile(fn)
975	if err != nil {
976		if os.IsNotExist(err) {
977			b.Skip("Missing testdata/silesia.tar.zst")
978			return
979		}
980		b.Fatal(err)
981	}
982	dec, err := NewReader(nil, WithDecoderLowmem(false))
983	if err != nil {
984		b.Fatal(err)
985	}
986	defer dec.Close()
987	err = dec.Reset(bytes.NewBuffer(data))
988	if err != nil {
989		b.Fatal(err)
990	}
991	n, err := io.Copy(ioutil.Discard, dec)
992	if err != nil {
993		b.Fatal(err)
994	}
995
996	b.SetBytes(n)
997	b.ReportAllocs()
998	b.ResetTimer()
999	for i := 0; i < b.N; i++ {
1000		err = dec.Reset(bytes.NewBuffer(data))
1001		if err != nil {
1002			b.Fatal(err)
1003		}
1004		_, err := io.CopyN(ioutil.Discard, dec, n)
1005		if err != nil {
1006			b.Fatal(err)
1007		}
1008	}
1009}
1010
1011func BenchmarkDecoderEnwik9(b *testing.B) {
1012	fn := "testdata/enwik9-1.zst"
1013	data, err := ioutil.ReadFile(fn)
1014	if err != nil {
1015		if os.IsNotExist(err) {
1016			b.Skip("Missing " + fn)
1017			return
1018		}
1019		b.Fatal(err)
1020	}
1021	dec, err := NewReader(nil, WithDecoderLowmem(false))
1022	if err != nil {
1023		b.Fatal(err)
1024	}
1025	defer dec.Close()
1026	err = dec.Reset(bytes.NewBuffer(data))
1027	if err != nil {
1028		b.Fatal(err)
1029	}
1030	n, err := io.Copy(ioutil.Discard, dec)
1031	if err != nil {
1032		b.Fatal(err)
1033	}
1034
1035	b.SetBytes(n)
1036	b.ReportAllocs()
1037	b.ResetTimer()
1038	for i := 0; i < b.N; i++ {
1039		err = dec.Reset(bytes.NewBuffer(data))
1040		if err != nil {
1041			b.Fatal(err)
1042		}
1043		_, err := io.CopyN(ioutil.Discard, dec, n)
1044		if err != nil {
1045			b.Fatal(err)
1046		}
1047	}
1048}
1049
1050func testDecoderDecodeAll(t *testing.T, fn string, dec *Decoder) {
1051	data, err := ioutil.ReadFile(fn)
1052	if err != nil {
1053		t.Fatal(err)
1054	}
1055	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
1056	if err != nil {
1057		t.Fatal(err)
1058	}
1059	var want = make(map[string][]byte)
1060	for _, tt := range zr.File {
1061		if strings.HasSuffix(tt.Name, ".zst") {
1062			continue
1063		}
1064		r, err := tt.Open()
1065		if err != nil {
1066			t.Fatal(err)
1067			return
1068		}
1069		want[tt.Name+".zst"], _ = ioutil.ReadAll(r)
1070	}
1071	var wg sync.WaitGroup
1072	for i, tt := range zr.File {
1073		tt := tt
1074		if !strings.HasSuffix(tt.Name, ".zst") || (testing.Short() && i > 20) {
1075			continue
1076		}
1077		wg.Add(1)
1078		t.Run("DecodeAll-"+tt.Name, func(t *testing.T) {
1079			defer wg.Done()
1080			t.Parallel()
1081			r, err := tt.Open()
1082			if err != nil {
1083				t.Fatal(err)
1084			}
1085			in, err := ioutil.ReadAll(r)
1086			if err != nil {
1087				t.Fatal(err)
1088			}
1089			wantB := want[tt.Name]
1090			// make a buffer that is too small.
1091			got, err := dec.DecodeAll(in, make([]byte, 10, 200))
1092			if err != nil {
1093				t.Error(err)
1094			}
1095			got = got[10:]
1096			if !bytes.Equal(wantB, got) {
1097				if len(wantB)+len(got) < 1000 {
1098					t.Logf(" got: %v\nwant: %v", got, wantB)
1099				} else {
1100					fileName, _ := filepath.Abs(filepath.Join("testdata", t.Name()+"-want.bin"))
1101					_ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm)
1102					err := ioutil.WriteFile(fileName, wantB, os.ModePerm)
1103					t.Log("Wrote file", fileName, err)
1104
1105					fileName, _ = filepath.Abs(filepath.Join("testdata", t.Name()+"-got.bin"))
1106					_ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm)
1107					err = ioutil.WriteFile(fileName, got, os.ModePerm)
1108					t.Log("Wrote file", fileName, err)
1109				}
1110				t.Logf("Length, want: %d, got: %d", len(wantB), len(got))
1111				t.Error("Output mismatch")
1112				return
1113			}
1114			t.Log(len(got), "bytes returned, matches input, ok!")
1115		})
1116	}
1117	go func() {
1118		wg.Wait()
1119		dec.Close()
1120	}()
1121}
1122
1123func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder) {
1124	data, err := ioutil.ReadFile(fn)
1125	if err != nil {
1126		t.Fatal(err)
1127	}
1128	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
1129	if err != nil {
1130		t.Fatal(err)
1131	}
1132
1133	var wg sync.WaitGroup
1134	for _, tt := range zr.File {
1135		tt := tt
1136		if !strings.HasSuffix(tt.Name, ".zst") {
1137			continue
1138		}
1139		wg.Add(1)
1140		t.Run("DecodeAll-"+tt.Name, func(t *testing.T) {
1141			defer wg.Done()
1142			t.Parallel()
1143			r, err := tt.Open()
1144			if err != nil {
1145				t.Fatal(err)
1146			}
1147			in, err := ioutil.ReadAll(r)
1148			if err != nil {
1149				t.Fatal(err)
1150			}
1151			// make a buffer that is too small.
1152			_, err = dec.DecodeAll(in, make([]byte, 0, 200))
1153			if err == nil {
1154				t.Error("Did not get expected error")
1155			}
1156		})
1157	}
1158	go func() {
1159		wg.Wait()
1160		dec.Close()
1161	}()
1162}
1163
1164// Test our predefined tables are correct.
1165// We don't predefine them, since this also tests our transformations.
1166// Reference from here: https://github.com/facebook/zstd/blob/ededcfca57366461021c922720878c81a5854a0a/lib/decompress/zstd_decompress_block.c#L234
func TestPredefTables(t *testing.T) {
	// x takes a row in the reference table's column order
	// (nextState, nbAddBits, nbBits, baseVal) and forwards it to newDecSymbol,
	// which takes (nbBits, nbAddBits, nextState, baseVal).
	x := func(nextState uint16, nbAddBits, nbBits uint8, baseVal uint32) decSymbol {
		return newDecSymbol(nbBits, nbAddBits, nextState, baseVal)
	}
	// Compare each entry of fsePredef with the reference rows for its table kind.
	for i := range fsePredef[:] {
		var want []decSymbol
		switch tableIndex(i) {
		case tableLiteralLengths:
			want = []decSymbol{
				/* nextState, nbAddBits, nbBits, baseVal */
				x(0, 0, 4, 0), x(16, 0, 4, 0),
				x(32, 0, 5, 1), x(0, 0, 5, 3),
				x(0, 0, 5, 4), x(0, 0, 5, 6),
				x(0, 0, 5, 7), x(0, 0, 5, 9),
				x(0, 0, 5, 10), x(0, 0, 5, 12),
				x(0, 0, 6, 14), x(0, 1, 5, 16),
				x(0, 1, 5, 20), x(0, 1, 5, 22),
				x(0, 2, 5, 28), x(0, 3, 5, 32),
				x(0, 4, 5, 48), x(32, 6, 5, 64),
				x(0, 7, 5, 128), x(0, 8, 6, 256),
				x(0, 10, 6, 1024), x(0, 12, 6, 4096),
				x(32, 0, 4, 0), x(0, 0, 4, 1),
				x(0, 0, 5, 2), x(32, 0, 5, 4),
				x(0, 0, 5, 5), x(32, 0, 5, 7),
				x(0, 0, 5, 8), x(32, 0, 5, 10),
				x(0, 0, 5, 11), x(0, 0, 6, 13),
				x(32, 1, 5, 16), x(0, 1, 5, 18),
				x(32, 1, 5, 22), x(0, 2, 5, 24),
				x(32, 3, 5, 32), x(0, 3, 5, 40),
				x(0, 6, 4, 64), x(16, 6, 4, 64),
				x(32, 7, 5, 128), x(0, 9, 6, 512),
				x(0, 11, 6, 2048), x(48, 0, 4, 0),
				x(16, 0, 4, 1), x(32, 0, 5, 2),
				x(32, 0, 5, 3), x(32, 0, 5, 5),
				x(32, 0, 5, 6), x(32, 0, 5, 8),
				x(32, 0, 5, 9), x(32, 0, 5, 11),
				x(32, 0, 5, 12), x(0, 0, 6, 15),
				x(32, 1, 5, 18), x(32, 1, 5, 20),
				x(32, 2, 5, 24), x(32, 2, 5, 28),
				x(32, 3, 5, 40), x(32, 4, 5, 48),
				x(0, 16, 6, 65536), x(0, 15, 6, 32768),
				x(0, 14, 6, 16384), x(0, 13, 6, 8192),
			}
		case tableOffsets:
			want = []decSymbol{
				/* nextState, nbAddBits, nbBits, baseVal */
				x(0, 0, 5, 0), x(0, 6, 4, 61),
				x(0, 9, 5, 509), x(0, 15, 5, 32765),
				x(0, 21, 5, 2097149), x(0, 3, 5, 5),
				x(0, 7, 4, 125), x(0, 12, 5, 4093),
				x(0, 18, 5, 262141), x(0, 23, 5, 8388605),
				x(0, 5, 5, 29), x(0, 8, 4, 253),
				x(0, 14, 5, 16381), x(0, 20, 5, 1048573),
				x(0, 2, 5, 1), x(16, 7, 4, 125),
				x(0, 11, 5, 2045), x(0, 17, 5, 131069),
				x(0, 22, 5, 4194301), x(0, 4, 5, 13),
				x(16, 8, 4, 253), x(0, 13, 5, 8189),
				x(0, 19, 5, 524285), x(0, 1, 5, 1),
				x(16, 6, 4, 61), x(0, 10, 5, 1021),
				x(0, 16, 5, 65533), x(0, 28, 5, 268435453),
				x(0, 27, 5, 134217725), x(0, 26, 5, 67108861),
				x(0, 25, 5, 33554429), x(0, 24, 5, 16777213),
			}
		case tableMatchLengths:
			want = []decSymbol{
				/* nextState, nbAddBits, nbBits, baseVal */
				x(0, 0, 6, 3), x(0, 0, 4, 4),
				x(32, 0, 5, 5), x(0, 0, 5, 6),
				x(0, 0, 5, 8), x(0, 0, 5, 9),
				x(0, 0, 5, 11), x(0, 0, 6, 13),
				x(0, 0, 6, 16), x(0, 0, 6, 19),
				x(0, 0, 6, 22), x(0, 0, 6, 25),
				x(0, 0, 6, 28), x(0, 0, 6, 31),
				x(0, 0, 6, 34), x(0, 1, 6, 37),
				x(0, 1, 6, 41), x(0, 2, 6, 47),
				x(0, 3, 6, 59), x(0, 4, 6, 83),
				x(0, 7, 6, 131), x(0, 9, 6, 515),
				x(16, 0, 4, 4), x(0, 0, 4, 5),
				x(32, 0, 5, 6), x(0, 0, 5, 7),
				x(32, 0, 5, 9), x(0, 0, 5, 10),
				x(0, 0, 6, 12), x(0, 0, 6, 15),
				x(0, 0, 6, 18), x(0, 0, 6, 21),
				x(0, 0, 6, 24), x(0, 0, 6, 27),
				x(0, 0, 6, 30), x(0, 0, 6, 33),
				x(0, 1, 6, 35), x(0, 1, 6, 39),
				x(0, 2, 6, 43), x(0, 3, 6, 51),
				x(0, 4, 6, 67), x(0, 5, 6, 99),
				x(0, 8, 6, 259), x(32, 0, 4, 4),
				x(48, 0, 4, 4), x(16, 0, 4, 5),
				x(32, 0, 5, 7), x(32, 0, 5, 8),
				x(32, 0, 5, 10), x(32, 0, 5, 11),
				x(0, 0, 6, 14), x(0, 0, 6, 17),
				x(0, 0, 6, 20), x(0, 0, 6, 23),
				x(0, 0, 6, 26), x(0, 0, 6, 29),
				x(0, 0, 6, 32), x(0, 16, 6, 65539),
				x(0, 15, 6, 32771), x(0, 14, 6, 16387),
				x(0, 13, 6, 8195), x(0, 12, 6, 4099),
				x(0, 11, 6, 2051), x(0, 10, 6, 1027),
			}
		}
		pre := fsePredef[i]
		// Compare only the first 1<<actualTableLog entries of the decode table.
		got := pre.dt[:1<<pre.actualTableLog]
		if !reflect.DeepEqual(got, want) {
			t.Logf("want: %v", want)
			t.Logf("got : %v", got)
			t.Errorf("Predefined table %d incorrect, len(got) = %d, len(want) = %d", i, len(got), len(want))
		}
	}
}
1276
// timeout starts a watchdog goroutine. If the returned cancel function is not
// called within after, the watchdog logs the stacks of all goroutines (up to
// 1MiB) and terminates the process with exit status 2. cancel closes an
// internal channel, so it must be called at most once.
func timeout(after time.Duration) (cancel func()) {
	deadline := time.After(after)
	done := make(chan struct{})
	go func() {
		select {
		case <-done:
			return
		case <-deadline:
		}
		stacks := make([]byte, 1<<20)
		size := runtime.Stack(stacks, true)
		log.Printf("=== Timeout, assuming deadlock ===\n*** goroutine dump...\n%s\n*** end\n", string(stacks[:size]))
		os.Exit(2)
	}()
	cancel = func() { close(done) }
	return cancel
}
1295