1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"bytes"
9	"fmt"
10	"io"
11	"io/ioutil"
12	"math/rand"
13	"os"
14	"runtime"
15	"strings"
16	"sync"
17	"testing"
18	"time"
19
20	"github.com/klauspost/compress/zip"
21	"github.com/klauspost/compress/zstd/internal/xxhash"
22)
23
// testWindowSizes lists the window sizes passed to WithWindowSize by the
// option matrix in getEncOpts and by the per-window subtests in this file.
var testWindowSizes = []int{MinWindowSize, 1 << 16, 1 << 22, 1 << 24}
25
// testEncOpt is one named encoder configuration produced by getEncOpts.
type testEncOpt struct {
	// name is used as the subtest name by the tests in this file.
	name string
	// o holds the options handed to NewWriter for this configuration.
	o    []EOption
}
30
31func getEncOpts(cMax int) []testEncOpt {
32	var o []testEncOpt
33	for level := speedNotSet + 1; level < speedLast; level++ {
34		for conc := 1; conc <= 4; conc *= 2 {
35			for _, wind := range testWindowSizes {
36				addOpt := func(name string, options ...EOption) {
37					opts := append([]EOption(nil), WithEncoderLevel(level), WithEncoderConcurrency(conc), WithWindowSize(wind))
38					name = fmt.Sprintf("%s-c%d-w%dk-%s", level.String(), conc, wind/1024, name)
39					o = append(o, testEncOpt{name: name, o: append(opts, options...)})
40				}
41				addOpt("default")
42				if testing.Short() {
43					break
44				}
45				addOpt("nocrc", WithEncoderCRC(false))
46				addOpt("lowmem", WithLowerEncoderMem(true))
47				addOpt("alllit", WithAllLitEntropyCompression(true))
48				addOpt("nolit", WithNoEntropyCompression(true))
49				addOpt("pad1k", WithEncoderPadding(1024))
50				addOpt("zerof", WithZeroFrames(true))
51				addOpt("singleseg", WithSingleSegment(true))
52			}
53			if testing.Short() && conc == 2 {
54				break
55			}
56			if conc >= cMax {
57				break
58			}
59		}
60	}
61	return o
62}
63
64func TestEncoder_EncodeAllSimple(t *testing.T) {
65	in, err := ioutil.ReadFile("testdata/z000028")
66	if err != nil {
67		t.Fatal(err)
68	}
69	dec, err := NewReader(nil)
70	if err != nil {
71		t.Fatal(err)
72	}
73	defer dec.Close()
74
75	in = append(in, in...)
76	for _, opts := range getEncOpts(4) {
77		t.Run(opts.name, func(t *testing.T) {
78			e, err := NewWriter(nil, opts.o...)
79			if err != nil {
80				t.Fatal(err)
81			}
82			defer e.Close()
83			start := time.Now()
84			dst := e.EncodeAll(in, nil)
85			t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
86			mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
87			t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
88
89			decoded, err := dec.DecodeAll(dst, nil)
90			if err != nil {
91				t.Error(err, len(decoded))
92			}
93			if !bytes.Equal(decoded, in) {
94				ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
95				ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
96				t.Fatal("Decoded does not match")
97			}
98			t.Log("Encoded content matched")
99		})
100	}
101}
102
103func TestEncoder_EncodeAllConcurrent(t *testing.T) {
104	in, err := ioutil.ReadFile("testdata/z000028")
105	if err != nil {
106		t.Fatal(err)
107	}
108	in = append(in, in...)
109
110	// When running race no more than 8k goroutines allowed.
111	n := 400 / runtime.GOMAXPROCS(0)
112	if testing.Short() {
113		n = 20 / runtime.GOMAXPROCS(0)
114	}
115	dec, err := NewReader(nil)
116	if err != nil {
117		t.Fatal(err)
118	}
119	defer dec.Close()
120	for _, opts := range getEncOpts(2) {
121		t.Run(opts.name, func(t *testing.T) {
122			rng := rand.New(rand.NewSource(0x1337))
123			e, err := NewWriter(nil, opts.o...)
124			if err != nil {
125				t.Fatal(err)
126			}
127			defer e.Close()
128			var wg sync.WaitGroup
129			wg.Add(n)
130			for i := 0; i < n; i++ {
131				in := in[rng.Int()&1023:]
132				in = in[:rng.Intn(len(in))]
133				go func() {
134					defer wg.Done()
135					dst := e.EncodeAll(in, nil)
136					//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
137					decoded, err := dec.DecodeAll(dst, nil)
138					if err != nil {
139						t.Error(err, len(decoded))
140					}
141					if !bytes.Equal(decoded, in) {
142						//ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
143						//ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
144						t.Error("Decoded does not match")
145						return
146					}
147				}()
148			}
149			wg.Wait()
150			t.Log("Encoded content matched.", n, "goroutines")
151		})
152	}
153}
154
155func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
156	f, err := os.Open("testdata/xml.zst")
157	if err != nil {
158		t.Fatal(err)
159	}
160	dec, err := NewReader(f)
161	if err != nil {
162		t.Fatal(err)
163	}
164	defer dec.Close()
165	in, err := ioutil.ReadAll(dec)
166	if err != nil {
167		t.Fatal(err)
168	}
169	if testing.Short() {
170		in = in[:10000]
171	}
172
173	for level := speedNotSet + 1; level < speedLast; level++ {
174		t.Run(level.String(), func(t *testing.T) {
175			e, err := NewWriter(nil, WithEncoderLevel(level))
176			if err != nil {
177				t.Fatal(err)
178			}
179			defer e.Close()
180			start := time.Now()
181			dst := e.EncodeAll(in, nil)
182			t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
183			mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
184			t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
185
186			decoded, err := dec.DecodeAll(dst, nil)
187			if err != nil {
188				t.Error(err, len(decoded))
189			}
190			if !bytes.Equal(decoded, in) {
191				ioutil.WriteFile("testdata/"+t.Name()+"-xml.got", decoded, os.ModePerm)
192				t.Error("Decoded does not match")
193				return
194			}
195			t.Log("Encoded content matched")
196		})
197	}
198}
199
200func TestEncoderRegression(t *testing.T) {
201	defer timeout(4 * time.Minute)()
202	data, err := ioutil.ReadFile("testdata/comp-crashers.zip")
203	if err != nil {
204		t.Fatal(err)
205	}
206	// We can't close the decoder.
207	dec, err := NewReader(nil)
208	if err != nil {
209		t.Error(err)
210		return
211	}
212	defer dec.Close()
213	for _, opts := range getEncOpts(2) {
214		t.Run(opts.name, func(t *testing.T) {
215			zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
216			if err != nil {
217				t.Fatal(err)
218			}
219			enc, err := NewWriter(
220				nil,
221				opts.o...,
222			)
223			if err != nil {
224				t.Fatal(err)
225			}
226			defer enc.Close()
227
228			for i, tt := range zr.File {
229				if !strings.HasSuffix(t.Name(), "") {
230					continue
231				}
232				if testing.Short() && i > 10 {
233					break
234				}
235
236				t.Run(tt.Name, func(t *testing.T) {
237					r, err := tt.Open()
238					if err != nil {
239						t.Error(err)
240						return
241					}
242					in, err := ioutil.ReadAll(r)
243					if err != nil {
244						t.Error(err)
245					}
246					encoded := enc.EncodeAll(in, nil)
247					got, err := dec.DecodeAll(encoded, nil)
248					if err != nil {
249						t.Logf("error: %v\nwant: %v\ngot:  %v", err, len(in), len(got))
250						t.Fatal(err)
251					}
252
253					// Use the Writer
254					var dst bytes.Buffer
255					enc.Reset(&dst)
256					_, err = enc.Write(in)
257					if err != nil {
258						t.Error(err)
259					}
260					err = enc.Close()
261					if err != nil {
262						t.Error(err)
263					}
264					encoded = dst.Bytes()
265					got, err = dec.DecodeAll(encoded, nil)
266					if err != nil {
267						t.Logf("error: %v\nwant: %v\ngot:  %v", err, in, got)
268						t.Error(err)
269					}
270				})
271			}
272		})
273	}
274}
275
276func TestEncoder_EncodeAllTwain(t *testing.T) {
277	in, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt")
278	if err != nil {
279		t.Fatal(err)
280	}
281	testWindowSizes := testWindowSizes
282	if testing.Short() {
283		testWindowSizes = []int{1 << 20}
284	}
285
286	dec, err := NewReader(nil)
287	if err != nil {
288		t.Fatal(err)
289	}
290	defer dec.Close()
291
292	for level := speedNotSet + 1; level < speedLast; level++ {
293		t.Run(level.String(), func(t *testing.T) {
294			for _, windowSize := range testWindowSizes {
295				t.Run(fmt.Sprintf("window:%d", windowSize), func(t *testing.T) {
296					e, err := NewWriter(nil, WithEncoderLevel(level), WithWindowSize(windowSize))
297					if err != nil {
298						t.Fatal(err)
299					}
300					defer e.Close()
301					start := time.Now()
302					dst := e.EncodeAll(in, nil)
303					t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
304					mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
305					t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
306
307					decoded, err := dec.DecodeAll(dst, nil)
308					if err != nil {
309						t.Error(err, len(decoded))
310					}
311					if !bytes.Equal(decoded, in) {
312						ioutil.WriteFile("testdata/"+t.Name()+"-Mark.Twain-Tom.Sawyer.txt.got", decoded, os.ModePerm)
313						t.Fatal("Decoded does not match")
314					}
315					t.Log("Encoded content matched")
316				})
317			}
318		})
319	}
320}
321
322func TestEncoder_EncodeAllPi(t *testing.T) {
323	in, err := ioutil.ReadFile("../testdata/pi.txt")
324	if err != nil {
325		t.Fatal(err)
326	}
327	testWindowSizes := testWindowSizes
328	if testing.Short() {
329		testWindowSizes = []int{1 << 20}
330	}
331
332	dec, err := NewReader(nil)
333	if err != nil {
334		t.Fatal(err)
335	}
336	defer dec.Close()
337
338	for level := speedNotSet + 1; level < speedLast; level++ {
339		t.Run(level.String(), func(t *testing.T) {
340			for _, windowSize := range testWindowSizes {
341				t.Run(fmt.Sprintf("window:%d", windowSize), func(t *testing.T) {
342					e, err := NewWriter(nil, WithEncoderLevel(level), WithWindowSize(windowSize))
343					if err != nil {
344						t.Fatal(err)
345					}
346					defer e.Close()
347					start := time.Now()
348					dst := e.EncodeAll(in, nil)
349					t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
350					mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
351					t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
352
353					decoded, err := dec.DecodeAll(dst, nil)
354					if err != nil {
355						t.Error(err, len(decoded))
356					}
357					if !bytes.Equal(decoded, in) {
358						ioutil.WriteFile("testdata/"+t.Name()+"-pi.txt.got", decoded, os.ModePerm)
359						t.Fatal("Decoded does not match")
360					}
361					t.Log("Encoded content matched")
362				})
363			}
364		})
365	}
366}
367
368func TestWithEncoderPadding(t *testing.T) {
369	n := 100
370	if testing.Short() {
371		n = 2
372	}
373	rng := rand.New(rand.NewSource(0x1337))
374	d, err := NewReader(nil)
375	if err != nil {
376		t.Fatal(err)
377	}
378	defer d.Close()
379
380	for i := 0; i < n; i++ {
381		padding := (rng.Int() & 0xfff) + 1
382		src := make([]byte, (rng.Int()&0xfffff)+1)
383		for i := range src {
384			src[i] = uint8(rng.Uint32()) & 7
385		}
386		e, err := NewWriter(nil, WithEncoderPadding(padding), WithEncoderCRC(rng.Uint32()&1 == 0))
387		if err != nil {
388			t.Fatal(err)
389		}
390		// Test the added padding is invisible.
391		dst := e.EncodeAll(src, nil)
392		if len(dst)%padding != 0 {
393			t.Fatalf("wanted size to be mutiple of %d, got size %d with remainder %d", padding, len(dst), len(dst)%padding)
394		}
395		got, err := d.DecodeAll(dst, nil)
396		if err != nil {
397			t.Fatal(err)
398		}
399		if !bytes.Equal(src, got) {
400			t.Fatal("output mismatch")
401		}
402		// Test when we supply data as well.
403		dst = e.EncodeAll(src, make([]byte, rng.Int()&255))
404		if len(dst)%padding != 0 {
405			t.Fatalf("wanted size to be mutiple of %d, got size %d with remainder %d", padding, len(dst), len(dst)%padding)
406		}
407
408		// Test using the writer.
409		var buf bytes.Buffer
410		e.Reset(&buf)
411		_, err = io.Copy(e, bytes.NewBuffer(src))
412		if err != nil {
413			t.Fatal(err)
414		}
415		err = e.Close()
416		if err != nil {
417			t.Fatal(err)
418		}
419		dst = buf.Bytes()
420		if len(dst)%padding != 0 {
421			t.Fatalf("wanted size to be mutiple of %d, got size %d with remainder %d", padding, len(dst), len(dst)%padding)
422		}
423		// Test the added padding is invisible.
424		got, err = d.DecodeAll(dst, nil)
425		if err != nil {
426			t.Fatal(err)
427		}
428		if !bytes.Equal(src, got) {
429			t.Fatal("output mismatch")
430		}
431		// Try after reset
432		buf.Reset()
433		e.Reset(&buf)
434		_, err = io.Copy(e, bytes.NewBuffer(src))
435		if err != nil {
436			t.Fatal(err)
437		}
438		err = e.Close()
439		if err != nil {
440			t.Fatal(err)
441		}
442		dst = buf.Bytes()
443		if len(dst)%padding != 0 {
444			t.Fatalf("wanted size to be mutiple of %d, got size %d with remainder %d", padding, len(dst), len(dst)%padding)
445		}
446		// Test the added padding is invisible.
447		got, err = d.DecodeAll(dst, nil)
448		if err != nil {
449			t.Fatal(err)
450		}
451		if !bytes.Equal(src, got) {
452			t.Fatal("output mismatch")
453		}
454	}
455}
456func TestEncoder_EncoderXML(t *testing.T) {
457	testEncoderRoundtrip(t, "./testdata/xml.zst", []byte{0x56, 0x54, 0x69, 0x8e, 0x40, 0x50, 0x11, 0xe})
458	testEncoderRoundtripWriter(t, "./testdata/xml.zst", []byte{0x56, 0x54, 0x69, 0x8e, 0x40, 0x50, 0x11, 0xe})
459}
460
461func TestEncoder_EncoderTwain(t *testing.T) {
462	testEncoderRoundtrip(t, "../testdata/Mark.Twain-Tom.Sawyer.txt", []byte{0x12, 0x1f, 0x12, 0x70, 0x79, 0x37, 0x1f, 0xc6})
463	testEncoderRoundtripWriter(t, "../testdata/Mark.Twain-Tom.Sawyer.txt", []byte{0x12, 0x1f, 0x12, 0x70, 0x79, 0x37, 0x1f, 0xc6})
464}
465
466func TestEncoder_EncoderPi(t *testing.T) {
467	testEncoderRoundtrip(t, "../testdata/pi.txt", []byte{0xe7, 0xe5, 0x25, 0x39, 0x92, 0xc7, 0x4a, 0xfb})
468	testEncoderRoundtripWriter(t, "../testdata/pi.txt", []byte{0xe7, 0xe5, 0x25, 0x39, 0x92, 0xc7, 0x4a, 0xfb})
469}
470
471func TestEncoder_EncoderSilesia(t *testing.T) {
472	testEncoderRoundtrip(t, "testdata/silesia.tar", []byte{0xa5, 0x5b, 0x5e, 0xe, 0x5e, 0xea, 0x51, 0x6b})
473	testEncoderRoundtripWriter(t, "testdata/silesia.tar", []byte{0xa5, 0x5b, 0x5e, 0xe, 0x5e, 0xea, 0x51, 0x6b})
474}
475
476func TestEncoder_EncoderSimple(t *testing.T) {
477	testEncoderRoundtrip(t, "testdata/z000028", []byte{0x8b, 0x2, 0x37, 0x70, 0x92, 0xb, 0x98, 0x95})
478	testEncoderRoundtripWriter(t, "testdata/z000028", []byte{0x8b, 0x2, 0x37, 0x70, 0x92, 0xb, 0x98, 0x95})
479}
480
481func TestEncoder_EncoderHTML(t *testing.T) {
482	testEncoderRoundtrip(t, "../testdata/html.txt", []byte{0x35, 0xa9, 0x5c, 0x37, 0x20, 0x9e, 0xc3, 0x37})
483	testEncoderRoundtripWriter(t, "../testdata/html.txt", []byte{0x35, 0xa9, 0x5c, 0x37, 0x20, 0x9e, 0xc3, 0x37})
484}
485
486func TestEncoder_EncoderEnwik9(t *testing.T) {
487	testEncoderRoundtrip(t, "./testdata/enwik9.zst", []byte{0x28, 0xfa, 0xf4, 0x30, 0xca, 0x4b, 0x64, 0x12})
488	testEncoderRoundtripWriter(t, "./testdata/enwik9.zst", []byte{0x28, 0xfa, 0xf4, 0x30, 0xca, 0x4b, 0x64, 0x12})
489}
490
491// test roundtrip using io.ReaderFrom interface.
492func testEncoderRoundtrip(t *testing.T, file string, wantCRC []byte) {
493	for _, opt := range getEncOpts(1) {
494		t.Run(opt.name, func(t *testing.T) {
495			opt := opt
496			t.Parallel()
497			f, err := os.Open(file)
498			if err != nil {
499				if os.IsNotExist(err) {
500					t.Skip("No input file:", file)
501					return
502				}
503				t.Fatal(err)
504			}
505			defer f.Close()
506			if stat, err := f.Stat(); testing.Short() && err == nil {
507				if stat.Size() > 10000 {
508					t.SkipNow()
509				}
510			}
511			input := io.Reader(f)
512			if strings.HasSuffix(file, ".zst") {
513				dec, err := NewReader(f)
514				if err != nil {
515					t.Fatal(err)
516				}
517				input = dec
518				defer dec.Close()
519			}
520
521			pr, pw := io.Pipe()
522			dec2, err := NewReader(pr)
523			if err != nil {
524				t.Fatal(err)
525			}
526			defer dec2.Close()
527
528			enc, err := NewWriter(pw, opt.o...)
529			if err != nil {
530				t.Fatal(err)
531			}
532			defer enc.Close()
533			var wantSize int64
534			start := time.Now()
535			go func() {
536				n, err := enc.ReadFrom(input)
537				if err != nil {
538					t.Error(err)
539					return
540				}
541				wantSize = n
542				err = enc.Close()
543				if err != nil {
544					t.Error(err)
545					return
546				}
547				pw.Close()
548			}()
549			var gotSize int64
550
551			// Check CRC
552			d := xxhash.New()
553			if true {
554				gotSize, err = io.Copy(d, dec2)
555			} else {
556				fout, err := os.Create(file + ".got")
557				if err != nil {
558					t.Fatal(err)
559				}
560				gotSize, err = io.Copy(io.MultiWriter(fout, d), dec2)
561				if err != nil {
562					t.Fatal(err)
563				}
564			}
565			if wantSize != gotSize {
566				t.Errorf("want size (%d) != got size (%d)", wantSize, gotSize)
567			}
568			if err != nil {
569				t.Fatal(err)
570			}
571			if gotCRC := d.Sum(nil); len(wantCRC) > 0 && !bytes.Equal(gotCRC, wantCRC) {
572				t.Errorf("crc mismatch %#v (want) != %#v (got)", wantCRC, gotCRC)
573			} else if len(wantCRC) != 8 {
574				t.Logf("Unable to verify CRC: %#v", gotCRC)
575			} else {
576				t.Logf("CRC Verified: %#v", gotCRC)
577			}
578			t.Log("Encoder len", wantSize)
579			mbpersec := (float64(wantSize) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
580			t.Logf("Encoded+Decoded %d bytes with %.2f MB/s", wantSize, mbpersec)
581		})
582	}
583}
584
// writerWrapper exposes only the Write method of the wrapped writer, so that
// any other interface the wrapped value implements is not visible to callers
// that hold the wrapper.
type writerWrapper struct {
	w io.Writer
}

// Write forwards p to the wrapped writer and returns its result unchanged.
func (ww writerWrapper) Write(p []byte) (int, error) {
	written, err := ww.w.Write(p)
	return written, err
}
592
593// test roundtrip using plain io.Writer interface.
594func testEncoderRoundtripWriter(t *testing.T, file string, wantCRC []byte) {
595	f, err := os.Open(file)
596	if err != nil {
597		if os.IsNotExist(err) {
598			t.Skip("No input file:", file)
599			return
600		}
601		t.Fatal(err)
602	}
603	defer f.Close()
604	if stat, err := f.Stat(); testing.Short() && err == nil {
605		if stat.Size() > 10000 {
606			t.SkipNow()
607		}
608	}
609	input := io.Reader(f)
610	if strings.HasSuffix(file, ".zst") {
611		dec, err := NewReader(f)
612		if err != nil {
613			t.Fatal(err)
614		}
615		input = dec
616		defer dec.Close()
617	}
618
619	pr, pw := io.Pipe()
620	dec2, err := NewReader(pr)
621	if err != nil {
622		t.Fatal(err)
623	}
624	defer dec2.Close()
625
626	enc, err := NewWriter(pw, WithEncoderCRC(true))
627	if err != nil {
628		t.Fatal(err)
629	}
630	defer enc.Close()
631	encW := writerWrapper{w: enc}
632	var wantSize int64
633	start := time.Now()
634	go func() {
635		n, err := io.CopyBuffer(encW, input, make([]byte, 1337))
636		if err != nil {
637			t.Error(err)
638			return
639		}
640		wantSize = n
641		err = enc.Close()
642		if err != nil {
643			t.Error(err)
644			return
645		}
646		pw.Close()
647	}()
648	var gotSize int64
649
650	// Check CRC
651	d := xxhash.New()
652	if true {
653		gotSize, err = io.Copy(d, dec2)
654	} else {
655		fout, err := os.Create(file + ".got")
656		if err != nil {
657			t.Fatal(err)
658		}
659		gotSize, err = io.Copy(io.MultiWriter(fout, d), dec2)
660		if err != nil {
661			t.Fatal(err)
662		}
663	}
664	if wantSize != gotSize {
665		t.Errorf("want size (%d) != got size (%d)", wantSize, gotSize)
666	}
667	if err != nil {
668		t.Fatal(err)
669	}
670	if gotCRC := d.Sum(nil); len(wantCRC) > 0 && !bytes.Equal(gotCRC, wantCRC) {
671		t.Errorf("crc mismatch %#v (want) != %#v (got)", wantCRC, gotCRC)
672	} else if len(wantCRC) != 8 {
673		t.Logf("Unable to verify CRC: %#v", gotCRC)
674	} else {
675		t.Logf("CRC Verified: %#v", gotCRC)
676	}
677	t.Log("Fast Encoder len", wantSize)
678	mbpersec := (float64(wantSize) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
679	t.Logf("Encoded+Decoded %d bytes with %.2f MB/s", wantSize, mbpersec)
680}
681
682func TestEncoder_EncodeAllSilesia(t *testing.T) {
683	if testing.Short() {
684		t.SkipNow()
685	}
686	in, err := ioutil.ReadFile("testdata/silesia.tar")
687	if err != nil {
688		if os.IsNotExist(err) {
689			t.Skip("Missing testdata/silesia.tar")
690			return
691		}
692		t.Fatal(err)
693	}
694
695	var e Encoder
696	start := time.Now()
697	dst := e.EncodeAll(in, nil)
698	t.Log("Fast Encoder len", len(in), "-> zstd len", len(dst))
699	mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
700	t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
701
702	dec, err := NewReader(nil, WithDecoderMaxMemory(220<<20))
703	if err != nil {
704		t.Fatal(err)
705	}
706	defer dec.Close()
707	decoded, err := dec.DecodeAll(dst, nil)
708	if err != nil {
709		t.Error(err, len(decoded))
710	}
711	if !bytes.Equal(decoded, in) {
712		ioutil.WriteFile("testdata/"+t.Name()+"-silesia.tar.got", decoded, os.ModePerm)
713		t.Fatal("Decoded does not match")
714	}
715	t.Log("Encoded content matched")
716}
717
718func TestEncoderReadFrom(t *testing.T) {
719	buffer := bytes.NewBuffer(nil)
720	encoder, err := NewWriter(buffer)
721	if err != nil {
722		t.Fatal(err)
723	}
724	if _, err := encoder.ReadFrom(strings.NewReader("0")); err != nil {
725		t.Fatal(err)
726	}
727	if err := encoder.Close(); err != nil {
728		t.Fatal(err)
729	}
730
731	dec, _ := NewReader(nil)
732	toDec := buffer.Bytes()
733	toDec = append(toDec, toDec...)
734	decoded, err := dec.DecodeAll(toDec, nil)
735	if err != nil {
736		t.Fatal(err)
737	}
738
739	if !bytes.Equal([]byte("00"), decoded) {
740		t.Logf("encoded: % x\n", buffer.Bytes())
741		t.Fatalf("output mismatch, got %s", string(decoded))
742	}
743	dec.Close()
744}
745
746func TestInterleavedWriteReadFrom(t *testing.T) {
747	var encoded bytes.Buffer
748
749	enc, err := NewWriter(&encoded)
750	if err != nil {
751		t.Fatal(err)
752	}
753
754	if _, err := enc.Write([]byte("write1")); err != nil {
755		t.Fatal(err)
756	}
757	if _, err := enc.Write([]byte("write2")); err != nil {
758		t.Fatal(err)
759	}
760	if _, err := enc.ReadFrom(strings.NewReader("readfrom1")); err != nil {
761		t.Fatal(err)
762	}
763	if _, err := enc.Write([]byte("write3")); err != nil {
764		t.Fatal(err)
765	}
766
767	if err := enc.Close(); err != nil {
768		t.Fatal(err)
769	}
770
771	dec, err := NewReader(&encoded)
772	if err != nil {
773		t.Fatal(err)
774	}
775	defer dec.Close()
776
777	gotb, err := ioutil.ReadAll(dec)
778	if err != nil {
779		t.Fatal(err)
780	}
781	got := string(gotb)
782
783	if want := "write1write2readfrom1write3"; got != want {
784		t.Errorf("got decoded %q, want %q", got, want)
785	}
786}
787
788func TestEncoder_EncodeAllEmpty(t *testing.T) {
789	if testing.Short() {
790		t.SkipNow()
791	}
792	var in []byte
793
794	for _, opt := range getEncOpts(1) {
795		t.Run(opt.name, func(t *testing.T) {
796			e, err := NewWriter(nil, opt.o...)
797			if err != nil {
798				t.Fatal(err)
799			}
800			defer e.Close()
801			dst := e.EncodeAll(in, nil)
802			t.Log("Block Encoder len", len(in), "-> zstd len", len(dst), dst)
803
804			dec, err := NewReader(nil, WithDecoderMaxMemory(220<<20))
805			if err != nil {
806				t.Fatal(err)
807			}
808			defer dec.Close()
809			decoded, err := dec.DecodeAll(dst, nil)
810			if err != nil {
811				t.Error(err, len(decoded))
812			}
813			if !bytes.Equal(decoded, in) {
814				t.Fatal("Decoded does not match")
815			}
816
817			// Test buffer writer.
818			var buf bytes.Buffer
819			e.Reset(&buf)
820			err = e.Close()
821			if err != nil {
822				t.Fatal(err)
823			}
824			dst = buf.Bytes()
825			t.Log("Buffer Encoder len", len(in), "-> zstd len", len(dst))
826
827			decoded, err = dec.DecodeAll(dst, nil)
828			if err != nil {
829				t.Error(err, len(decoded))
830			}
831			if !bytes.Equal(decoded, in) {
832				t.Fatal("Decoded does not match")
833			}
834
835			t.Log("Encoded content matched")
836		})
837	}
838}
839
840func TestEncoder_EncodeAllEnwik9(t *testing.T) {
841	if false || testing.Short() {
842		t.SkipNow()
843	}
844	file := "testdata/enwik9.zst"
845	f, err := os.Open(file)
846	if err != nil {
847		if os.IsNotExist(err) {
848			t.Skip("To run extended tests, download http://mattmahoney.net/dc/enwik9.zip unzip it \n" +
849				"compress it with 'zstd -15 -T0 enwik9' and place it in " + file)
850		}
851	}
852	dec, err := NewReader(f)
853	if err != nil {
854		t.Fatal(err)
855	}
856	defer dec.Close()
857	in, err := ioutil.ReadAll(dec)
858	if err != nil {
859		t.Fatal(err)
860	}
861
862	start := time.Now()
863	var e Encoder
864	dst := e.EncodeAll(in, nil)
865	t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
866	mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
867	t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
868	decoded, err := dec.DecodeAll(dst, nil)
869	if err != nil {
870		t.Error(err, len(decoded))
871	}
872	if !bytes.Equal(decoded, in) {
873		ioutil.WriteFile("testdata/"+t.Name()+"-enwik9.got", decoded, os.ModePerm)
874		t.Fatal("Decoded does not match")
875	}
876	t.Log("Encoded content matched")
877}
878
879func BenchmarkEncoder_EncodeAllXML(b *testing.B) {
880	f, err := os.Open("testdata/xml.zst")
881	if err != nil {
882		b.Fatal(err)
883	}
884	dec, err := NewReader(f)
885	if err != nil {
886		b.Fatal(err)
887	}
888	in, err := ioutil.ReadAll(dec)
889	if err != nil {
890		b.Fatal(err)
891	}
892	dec.Close()
893
894	enc := Encoder{}
895	dst := enc.EncodeAll(in, nil)
896	wantSize := len(dst)
897	b.Log("Output size:", len(dst))
898	b.ResetTimer()
899	b.ReportAllocs()
900	b.SetBytes(int64(len(in)))
901	for i := 0; i < b.N; i++ {
902		dst := enc.EncodeAll(in, dst[:0])
903		if len(dst) != wantSize {
904			b.Fatal(len(dst), "!=", wantSize)
905		}
906	}
907}
908
909func BenchmarkEncoder_EncodeAllSimple(b *testing.B) {
910	f, err := os.Open("testdata/z000028")
911	if err != nil {
912		b.Fatal(err)
913	}
914	in, err := ioutil.ReadAll(f)
915	if err != nil {
916		b.Fatal(err)
917	}
918
919	for level := speedNotSet + 1; level < speedLast; level++ {
920		b.Run(level.String(), func(b *testing.B) {
921			enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level))
922			if err != nil {
923				b.Fatal(err)
924			}
925			defer enc.Close()
926			dst := enc.EncodeAll(in, nil)
927			wantSize := len(dst)
928			b.ResetTimer()
929			b.ReportAllocs()
930			b.SetBytes(int64(len(in)))
931			for i := 0; i < b.N; i++ {
932				dst := enc.EncodeAll(in, dst[:0])
933				if len(dst) != wantSize {
934					b.Fatal(len(dst), "!=", wantSize)
935				}
936			}
937		})
938	}
939}
940
941func BenchmarkEncoder_EncodeAllSimple4K(b *testing.B) {
942	f, err := os.Open("testdata/z000028")
943	if err != nil {
944		b.Fatal(err)
945	}
946	in, err := ioutil.ReadAll(f)
947	if err != nil {
948		b.Fatal(err)
949	}
950	in = in[:4096]
951
952	for level := speedNotSet + 1; level < speedLast; level++ {
953		b.Run(level.String(), func(b *testing.B) {
954			enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level))
955			if err != nil {
956				b.Fatal(err)
957			}
958			defer enc.Close()
959			dst := enc.EncodeAll(in, nil)
960			wantSize := len(dst)
961			b.ResetTimer()
962			b.ReportAllocs()
963			b.SetBytes(int64(len(in)))
964			for i := 0; i < b.N; i++ {
965				dst := enc.EncodeAll(in, dst[:0])
966				if len(dst) != wantSize {
967					b.Fatal(len(dst), "!=", wantSize)
968				}
969			}
970		})
971	}
972}
973
974func BenchmarkEncoder_EncodeAllHTML(b *testing.B) {
975	f, err := os.Open("../testdata/html.txt")
976	if err != nil {
977		b.Fatal(err)
978	}
979	in, err := ioutil.ReadAll(f)
980	if err != nil {
981		b.Fatal(err)
982	}
983
984	enc := Encoder{}
985	dst := enc.EncodeAll(in, nil)
986	wantSize := len(dst)
987	b.ResetTimer()
988	b.ReportAllocs()
989	b.SetBytes(int64(len(in)))
990	for i := 0; i < b.N; i++ {
991		dst := enc.EncodeAll(in, dst[:0])
992		if len(dst) != wantSize {
993			b.Fatal(len(dst), "!=", wantSize)
994		}
995	}
996}
997
998func BenchmarkEncoder_EncodeAllTwain(b *testing.B) {
999	f, err := os.Open("../testdata/Mark.Twain-Tom.Sawyer.txt")
1000	if err != nil {
1001		b.Fatal(err)
1002	}
1003	in, err := ioutil.ReadAll(f)
1004	if err != nil {
1005		b.Fatal(err)
1006	}
1007
1008	enc := Encoder{}
1009	dst := enc.EncodeAll(in, nil)
1010	wantSize := len(dst)
1011	b.ResetTimer()
1012	b.ReportAllocs()
1013	b.SetBytes(int64(len(in)))
1014	for i := 0; i < b.N; i++ {
1015		dst := enc.EncodeAll(in, dst[:0])
1016		if len(dst) != wantSize {
1017			b.Fatal(len(dst), "!=", wantSize)
1018		}
1019	}
1020}
1021
1022func BenchmarkEncoder_EncodeAllPi(b *testing.B) {
1023	f, err := os.Open("../testdata/pi.txt")
1024	if err != nil {
1025		b.Fatal(err)
1026	}
1027	in, err := ioutil.ReadAll(f)
1028	if err != nil {
1029		b.Fatal(err)
1030	}
1031
1032	enc := Encoder{}
1033	dst := enc.EncodeAll(in, nil)
1034	wantSize := len(dst)
1035	b.ResetTimer()
1036	b.ReportAllocs()
1037	b.SetBytes(int64(len(in)))
1038	for i := 0; i < b.N; i++ {
1039		dst := enc.EncodeAll(in, dst[:0])
1040		if len(dst) != wantSize {
1041			b.Fatal(len(dst), "!=", wantSize)
1042		}
1043	}
1044}
1045
1046func BenchmarkRandom4KEncodeAllFastest(b *testing.B) {
1047	rng := rand.New(rand.NewSource(1))
1048	data := make([]byte, 4<<10)
1049	for i := range data {
1050		data[i] = uint8(rng.Intn(256))
1051	}
1052	enc, _ := NewWriter(nil, WithEncoderLevel(SpeedFastest), WithEncoderConcurrency(1))
1053	defer enc.Close()
1054	dst := enc.EncodeAll(data, nil)
1055	wantSize := len(dst)
1056	b.ResetTimer()
1057	b.ReportAllocs()
1058	b.SetBytes(int64(len(data)))
1059	for i := 0; i < b.N; i++ {
1060		dst := enc.EncodeAll(data, dst[:0])
1061		if len(dst) != wantSize {
1062			b.Fatal(len(dst), "!=", wantSize)
1063		}
1064	}
1065}
1066
1067func BenchmarkRandom10MBEncodeAllFastest(b *testing.B) {
1068	rng := rand.New(rand.NewSource(1))
1069	data := make([]byte, 10<<20)
1070	rng.Read(data)
1071	enc, _ := NewWriter(nil, WithEncoderLevel(SpeedFastest), WithEncoderConcurrency(2))
1072	defer enc.Close()
1073	dst := enc.EncodeAll(data, nil)
1074	wantSize := len(dst)
1075	b.ResetTimer()
1076	b.ReportAllocs()
1077	b.SetBytes(int64(len(data)))
1078	for i := 0; i < b.N; i++ {
1079		dst := enc.EncodeAll(data, dst[:0])
1080		if len(dst) != wantSize {
1081			b.Fatal(len(dst), "!=", wantSize)
1082		}
1083	}
1084}
1085
1086func BenchmarkRandom4KEncodeAllDefault(b *testing.B) {
1087	rng := rand.New(rand.NewSource(1))
1088	data := make([]byte, 4<<10)
1089	rng.Read(data)
1090	enc, _ := NewWriter(nil, WithEncoderLevel(SpeedDefault), WithEncoderConcurrency(1))
1091	defer enc.Close()
1092	dst := enc.EncodeAll(data, nil)
1093	wantSize := len(dst)
1094	b.ResetTimer()
1095	b.ReportAllocs()
1096	b.SetBytes(int64(len(data)))
1097	for i := 0; i < b.N; i++ {
1098		dst := enc.EncodeAll(data, dst[:0])
1099		if len(dst) != wantSize {
1100			b.Fatal(len(dst), "!=", wantSize)
1101		}
1102	}
1103}
1104
1105func BenchmarkRandomEncodeAllDefault(b *testing.B) {
1106	rng := rand.New(rand.NewSource(1))
1107	data := make([]byte, 10<<20)
1108	rng.Read(data)
1109	enc, _ := NewWriter(nil, WithEncoderLevel(SpeedDefault), WithEncoderConcurrency(1))
1110	defer enc.Close()
1111	dst := enc.EncodeAll(data, nil)
1112	wantSize := len(dst)
1113	b.ResetTimer()
1114	b.ReportAllocs()
1115	b.SetBytes(int64(len(data)))
1116	for i := 0; i < b.N; i++ {
1117		dst := enc.EncodeAll(data, dst[:0])
1118		if len(dst) != wantSize {
1119			b.Fatal(len(dst), "!=", wantSize)
1120		}
1121	}
1122}
1123
1124func BenchmarkRandom10MBEncoderFastest(b *testing.B) {
1125	rng := rand.New(rand.NewSource(1))
1126	data := make([]byte, 10<<20)
1127	rng.Read(data)
1128	wantSize := int64(len(data))
1129	enc, _ := NewWriter(ioutil.Discard, WithEncoderLevel(SpeedFastest))
1130	defer enc.Close()
1131	n, err := io.Copy(enc, bytes.NewBuffer(data))
1132	if err != nil {
1133		b.Fatal(err)
1134	}
1135	if n != wantSize {
1136		b.Fatal(n, "!=", wantSize)
1137	}
1138	b.ResetTimer()
1139	b.ReportAllocs()
1140	b.SetBytes(wantSize)
1141	for i := 0; i < b.N; i++ {
1142		enc.Reset(ioutil.Discard)
1143		n, err := io.Copy(enc, bytes.NewBuffer(data))
1144		if err != nil {
1145			b.Fatal(err)
1146		}
1147		if n != wantSize {
1148			b.Fatal(n, "!=", wantSize)
1149		}
1150	}
1151}
1152
1153func BenchmarkRandomEncoderDefault(b *testing.B) {
1154	rng := rand.New(rand.NewSource(1))
1155	data := make([]byte, 10<<20)
1156	rng.Read(data)
1157	wantSize := int64(len(data))
1158	enc, _ := NewWriter(ioutil.Discard, WithEncoderLevel(SpeedDefault))
1159	defer enc.Close()
1160	n, err := io.Copy(enc, bytes.NewBuffer(data))
1161	if err != nil {
1162		b.Fatal(err)
1163	}
1164	if n != wantSize {
1165		b.Fatal(n, "!=", wantSize)
1166	}
1167	b.ResetTimer()
1168	b.ReportAllocs()
1169	b.SetBytes(wantSize)
1170	for i := 0; i < b.N; i++ {
1171		enc.Reset(ioutil.Discard)
1172		n, err := io.Copy(enc, bytes.NewBuffer(data))
1173		if err != nil {
1174			b.Fatal(err)
1175		}
1176		if n != wantSize {
1177			b.Fatal(n, "!=", wantSize)
1178		}
1179	}
1180}
1181