1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"bytes"
9	"fmt"
10	"io"
11	"io/ioutil"
12	"math/rand"
13	"os"
14	"runtime"
15	"strings"
16	"sync"
17	"testing"
18	"time"
19
20	"github.com/klauspost/compress/zip"
21	"github.com/klauspost/compress/zstd/internal/xxhash"
22)
23
// testWindowSizes lists the window sizes passed to WithWindowSize by the
// window-parameterized tests in this file. Those tests replace the list with
// a single 1<<20 entry when testing.Short() is set.
var testWindowSizes = []int{MinWindowSize, 1 << 16, 1 << 22, 1 << 24}
25
26func TestEncoder_EncodeAllSimple(t *testing.T) {
27	in, err := ioutil.ReadFile("testdata/z000028")
28	if err != nil {
29		t.Fatal(err)
30	}
31	dec, err := NewReader(nil)
32	if err != nil {
33		t.Fatal(err)
34	}
35	defer dec.Close()
36
37	in = append(in, in...)
38	for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
39		t.Run(level.String(), func(t *testing.T) {
40			e, err := NewWriter(nil, WithEncoderLevel(level), WithEncoderConcurrency(2), WithWindowSize(128<<10), WithZeroFrames(true))
41			if err != nil {
42				t.Fatal(err)
43			}
44			defer e.Close()
45			start := time.Now()
46			dst := e.EncodeAll(in, nil)
47			t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
48			mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
49			t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
50
51			decoded, err := dec.DecodeAll(dst, nil)
52			if err != nil {
53				t.Error(err, len(decoded))
54			}
55			if !bytes.Equal(decoded, in) {
56				ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
57				ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
58				t.Fatal("Decoded does not match")
59			}
60			t.Log("Encoded content matched")
61		})
62	}
63}
64
65func TestEncoder_EncodeAllConcurrent(t *testing.T) {
66	in, err := ioutil.ReadFile("testdata/z000028")
67	if err != nil {
68		t.Fatal(err)
69	}
70	in = append(in, in...)
71
72	// When running race no more than 8k goroutines allowed.
73	n := 4000 / runtime.GOMAXPROCS(0)
74	if testing.Short() {
75		n = 200 / runtime.GOMAXPROCS(0)
76	}
77	dec, err := NewReader(nil)
78	if err != nil {
79		t.Fatal(err)
80	}
81	defer dec.Close()
82	for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
83		t.Run(level.String(), func(t *testing.T) {
84			rng := rand.New(rand.NewSource(0x1337))
85			e, err := NewWriter(nil, WithEncoderLevel(level), WithZeroFrames(true))
86			if err != nil {
87				t.Fatal(err)
88			}
89			defer e.Close()
90			var wg sync.WaitGroup
91			wg.Add(n)
92			for i := 0; i < n; i++ {
93				in := in[rng.Int()&1023:]
94				in = in[:rng.Intn(len(in))]
95				go func() {
96					defer wg.Done()
97					dst := e.EncodeAll(in, nil)
98					//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
99					decoded, err := dec.DecodeAll(dst, nil)
100					if err != nil {
101						t.Error(err, len(decoded))
102					}
103					if !bytes.Equal(decoded, in) {
104						//ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
105						//ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
106						t.Fatal("Decoded does not match")
107					}
108				}()
109			}
110			wg.Wait()
111			t.Log("Encoded content matched.", n, "goroutines")
112		})
113	}
114}
115
116func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
117	f, err := os.Open("testdata/xml.zst")
118	if err != nil {
119		t.Fatal(err)
120	}
121	dec, err := NewReader(f)
122	if err != nil {
123		t.Fatal(err)
124	}
125	defer dec.Close()
126	in, err := ioutil.ReadAll(dec)
127	if err != nil {
128		t.Fatal(err)
129	}
130
131	for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
132		t.Run(level.String(), func(t *testing.T) {
133			e, err := NewWriter(nil, WithEncoderLevel(level))
134			if err != nil {
135				t.Fatal(err)
136			}
137			defer e.Close()
138			start := time.Now()
139			dst := e.EncodeAll(in, nil)
140			t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
141			mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
142			t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
143
144			decoded, err := dec.DecodeAll(dst, nil)
145			if err != nil {
146				t.Error(err, len(decoded))
147			}
148			if !bytes.Equal(decoded, in) {
149				ioutil.WriteFile("testdata/"+t.Name()+"-xml.got", decoded, os.ModePerm)
150				t.Fatal("Decoded does not match")
151			}
152			t.Log("Encoded content matched")
153		})
154	}
155}
156
157func TestEncoderRegression(t *testing.T) {
158	defer timeout(2 * time.Minute)()
159	data, err := ioutil.ReadFile("testdata/comp-crashers.zip")
160	if err != nil {
161		t.Fatal(err)
162	}
163	// We can't close the decoder.
164	dec, err := NewReader(nil)
165	if err != nil {
166		t.Error(err)
167		return
168	}
169	defer dec.Close()
170	testWindowSizes := testWindowSizes
171	if testing.Short() {
172		testWindowSizes = []int{1 << 20}
173	}
174	for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
175		t.Run(level.String(), func(t *testing.T) {
176			for _, windowSize := range testWindowSizes {
177				t.Run(fmt.Sprintf("window:%d", windowSize), func(t *testing.T) {
178					zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
179					if err != nil {
180						t.Fatal(err)
181					}
182					enc, err := NewWriter(
183						nil,
184						WithEncoderCRC(true),
185						WithEncoderLevel(level),
186						WithWindowSize(windowSize),
187					)
188					if err != nil {
189						t.Fatal(err)
190					}
191					defer enc.Close()
192
193					for i, tt := range zr.File {
194						if !strings.HasSuffix(t.Name(), "") {
195							continue
196						}
197						if testing.Short() && i > 100 {
198							break
199						}
200
201						t.Run(tt.Name, func(t *testing.T) {
202							r, err := tt.Open()
203							if err != nil {
204								t.Error(err)
205								return
206							}
207							in, err := ioutil.ReadAll(r)
208							if err != nil {
209								t.Error(err)
210							}
211							encoded := enc.EncodeAll(in, nil)
212							got, err := dec.DecodeAll(encoded, nil)
213							if err != nil {
214								t.Logf("error: %v\nwant: %v\ngot:  %v", err, len(in), len(got))
215								t.Fatal(err)
216							}
217
218							// Use the Writer
219							var dst bytes.Buffer
220							enc.Reset(&dst)
221							_, err = enc.Write(in)
222							if err != nil {
223								t.Error(err)
224							}
225							err = enc.Close()
226							if err != nil {
227								t.Error(err)
228							}
229							encoded = dst.Bytes()
230							got, err = dec.DecodeAll(encoded, nil)
231							if err != nil {
232								t.Logf("error: %v\nwant: %v\ngot:  %v", err, in, got)
233								t.Fatal(err)
234							}
235
236						})
237					}
238				})
239			}
240		})
241	}
242}
243
244func TestEncoder_EncodeAllTwain(t *testing.T) {
245	in, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt")
246	if err != nil {
247		t.Fatal(err)
248	}
249	testWindowSizes := testWindowSizes
250	if testing.Short() {
251		testWindowSizes = []int{1 << 20}
252	}
253
254	dec, err := NewReader(nil)
255	if err != nil {
256		t.Fatal(err)
257	}
258	defer dec.Close()
259
260	for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
261		t.Run(level.String(), func(t *testing.T) {
262			for _, windowSize := range testWindowSizes {
263				t.Run(fmt.Sprintf("window:%d", windowSize), func(t *testing.T) {
264					e, err := NewWriter(nil, WithEncoderLevel(level), WithWindowSize(windowSize))
265					if err != nil {
266						t.Fatal(err)
267					}
268					defer e.Close()
269					start := time.Now()
270					dst := e.EncodeAll(in, nil)
271					t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
272					mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
273					t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
274
275					decoded, err := dec.DecodeAll(dst, nil)
276					if err != nil {
277						t.Error(err, len(decoded))
278					}
279					if !bytes.Equal(decoded, in) {
280						ioutil.WriteFile("testdata/"+t.Name()+"-Mark.Twain-Tom.Sawyer.txt.got", decoded, os.ModePerm)
281						t.Fatal("Decoded does not match")
282					}
283					t.Log("Encoded content matched")
284				})
285			}
286		})
287	}
288}
289
290func TestEncoder_EncodeAllPi(t *testing.T) {
291	in, err := ioutil.ReadFile("../testdata/pi.txt")
292	if err != nil {
293		t.Fatal(err)
294	}
295	testWindowSizes := testWindowSizes
296	if testing.Short() {
297		testWindowSizes = []int{1 << 20}
298	}
299
300	dec, err := NewReader(nil)
301	if err != nil {
302		t.Fatal(err)
303	}
304	defer dec.Close()
305
306	for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
307		t.Run(level.String(), func(t *testing.T) {
308			for _, windowSize := range testWindowSizes {
309				t.Run(fmt.Sprintf("window:%d", windowSize), func(t *testing.T) {
310					e, err := NewWriter(nil, WithEncoderLevel(level), WithWindowSize(windowSize))
311					if err != nil {
312						t.Fatal(err)
313					}
314					defer e.Close()
315					start := time.Now()
316					dst := e.EncodeAll(in, nil)
317					t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
318					mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
319					t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
320
321					decoded, err := dec.DecodeAll(dst, nil)
322					if err != nil {
323						t.Error(err, len(decoded))
324					}
325					if !bytes.Equal(decoded, in) {
326						ioutil.WriteFile("testdata/"+t.Name()+"-pi.txt.got", decoded, os.ModePerm)
327						t.Fatal("Decoded does not match")
328					}
329					t.Log("Encoded content matched")
330				})
331			}
332		})
333	}
334}
335
336func TestWithEncoderPadding(t *testing.T) {
337	n := 100
338	if testing.Short() {
339		n = 5
340	}
341	rng := rand.New(rand.NewSource(0x1337))
342	d, err := NewReader(nil)
343	if err != nil {
344		t.Fatal(err)
345	}
346	defer d.Close()
347
348	for i := 0; i < n; i++ {
349		padding := (rng.Int() & 0xfff) + 1
350		src := make([]byte, (rng.Int()&0xfffff)+1)
351		for i := range src {
352			src[i] = uint8(rng.Uint32()) & 7
353		}
354		e, err := NewWriter(nil, WithEncoderPadding(padding), WithEncoderCRC(rng.Uint32()&1 == 0))
355		if err != nil {
356			t.Fatal(err)
357		}
358		// Test the added padding is invisible.
359		dst := e.EncodeAll(src, nil)
360		if len(dst)%padding != 0 {
361			t.Fatalf("wanted size to be mutiple of %d, got size %d with remainder %d", padding, len(dst), len(dst)%padding)
362		}
363		got, err := d.DecodeAll(dst, nil)
364		if err != nil {
365			t.Fatal(err)
366		}
367		if !bytes.Equal(src, got) {
368			t.Fatal("output mismatch")
369		}
370		// Test when we supply data as well.
371		dst = e.EncodeAll(src, make([]byte, rng.Int()&255))
372		if len(dst)%padding != 0 {
373			t.Fatalf("wanted size to be mutiple of %d, got size %d with remainder %d", padding, len(dst), len(dst)%padding)
374		}
375
376		// Test using the writer.
377		var buf bytes.Buffer
378		e.Reset(&buf)
379		_, err = io.Copy(e, bytes.NewBuffer(src))
380		if err != nil {
381			t.Fatal(err)
382		}
383		err = e.Close()
384		if err != nil {
385			t.Fatal(err)
386		}
387		dst = buf.Bytes()
388		if len(dst)%padding != 0 {
389			t.Fatalf("wanted size to be mutiple of %d, got size %d with remainder %d", padding, len(dst), len(dst)%padding)
390		}
391		// Test the added padding is invisible.
392		got, err = d.DecodeAll(dst, nil)
393		if err != nil {
394			t.Fatal(err)
395		}
396		if !bytes.Equal(src, got) {
397			t.Fatal("output mismatch")
398		}
399		// Try after reset
400		buf.Reset()
401		e.Reset(&buf)
402		_, err = io.Copy(e, bytes.NewBuffer(src))
403		if err != nil {
404			t.Fatal(err)
405		}
406		err = e.Close()
407		if err != nil {
408			t.Fatal(err)
409		}
410		dst = buf.Bytes()
411		if len(dst)%padding != 0 {
412			t.Fatalf("wanted size to be mutiple of %d, got size %d with remainder %d", padding, len(dst), len(dst)%padding)
413		}
414		// Test the added padding is invisible.
415		got, err = d.DecodeAll(dst, nil)
416		if err != nil {
417			t.Fatal(err)
418		}
419		if !bytes.Equal(src, got) {
420			t.Fatal("output mismatch")
421		}
422	}
423}
424func TestEncoder_EncoderXML(t *testing.T) {
425	testEncoderRoundtrip(t, "./testdata/xml.zst", []byte{0x56, 0x54, 0x69, 0x8e, 0x40, 0x50, 0x11, 0xe})
426	testEncoderRoundtripWriter(t, "./testdata/xml.zst", []byte{0x56, 0x54, 0x69, 0x8e, 0x40, 0x50, 0x11, 0xe})
427}
428
429func TestEncoder_EncoderTwain(t *testing.T) {
430	testEncoderRoundtrip(t, "../testdata/Mark.Twain-Tom.Sawyer.txt", []byte{0x12, 0x1f, 0x12, 0x70, 0x79, 0x37, 0x1f, 0xc6})
431	testEncoderRoundtripWriter(t, "../testdata/Mark.Twain-Tom.Sawyer.txt", []byte{0x12, 0x1f, 0x12, 0x70, 0x79, 0x37, 0x1f, 0xc6})
432}
433
434func TestEncoder_EncoderPi(t *testing.T) {
435	testEncoderRoundtrip(t, "../testdata/pi.txt", []byte{0xe7, 0xe5, 0x25, 0x39, 0x92, 0xc7, 0x4a, 0xfb})
436	testEncoderRoundtripWriter(t, "../testdata/pi.txt", []byte{0xe7, 0xe5, 0x25, 0x39, 0x92, 0xc7, 0x4a, 0xfb})
437}
438
439func TestEncoder_EncoderSilesia(t *testing.T) {
440	testEncoderRoundtrip(t, "testdata/silesia.tar", []byte{0xa5, 0x5b, 0x5e, 0xe, 0x5e, 0xea, 0x51, 0x6b})
441	testEncoderRoundtripWriter(t, "testdata/silesia.tar", []byte{0xa5, 0x5b, 0x5e, 0xe, 0x5e, 0xea, 0x51, 0x6b})
442}
443
444func TestEncoder_EncoderSimple(t *testing.T) {
445	testEncoderRoundtrip(t, "testdata/z000028", []byte{0x8b, 0x2, 0x37, 0x70, 0x92, 0xb, 0x98, 0x95})
446	testEncoderRoundtripWriter(t, "testdata/z000028", []byte{0x8b, 0x2, 0x37, 0x70, 0x92, 0xb, 0x98, 0x95})
447}
448
449func TestEncoder_EncoderHTML(t *testing.T) {
450	testEncoderRoundtrip(t, "../testdata/html.txt", []byte{0x35, 0xa9, 0x5c, 0x37, 0x20, 0x9e, 0xc3, 0x37})
451	testEncoderRoundtripWriter(t, "../testdata/html.txt", []byte{0x35, 0xa9, 0x5c, 0x37, 0x20, 0x9e, 0xc3, 0x37})
452}
453
454func TestEncoder_EncoderEnwik9(t *testing.T) {
455	testEncoderRoundtrip(t, "./testdata/enwik9.zst", []byte{0x28, 0xfa, 0xf4, 0x30, 0xca, 0x4b, 0x64, 0x12})
456	testEncoderRoundtripWriter(t, "./testdata/enwik9.zst", []byte{0x28, 0xfa, 0xf4, 0x30, 0xca, 0x4b, 0x64, 0x12})
457}
458
459// test roundtrip using io.ReaderFrom interface.
460func testEncoderRoundtrip(t *testing.T, file string, wantCRC []byte) {
461	for level := speedNotSet + 1; level < speedLast; level++ {
462		t.Run(level.String(), func(t *testing.T) {
463			level := level
464			t.Parallel()
465			f, err := os.Open(file)
466			if err != nil {
467				if os.IsNotExist(err) {
468					t.Skip("No input file:", file)
469					return
470				}
471				t.Fatal(err)
472			}
473			defer f.Close()
474			input := io.Reader(f)
475			if strings.HasSuffix(file, ".zst") {
476				dec, err := NewReader(f)
477				if err != nil {
478					t.Fatal(err)
479				}
480				input = dec
481				defer dec.Close()
482			}
483
484			pr, pw := io.Pipe()
485			dec2, err := NewReader(pr)
486			if err != nil {
487				t.Fatal(err)
488			}
489			defer dec2.Close()
490
491			enc, err := NewWriter(pw, WithEncoderCRC(true), WithEncoderLevel(level))
492			if err != nil {
493				t.Fatal(err)
494			}
495			defer enc.Close()
496			var wantSize int64
497			start := time.Now()
498			go func() {
499				n, err := enc.ReadFrom(input)
500				if err != nil {
501					t.Fatal(err)
502				}
503				wantSize = n
504				err = enc.Close()
505				if err != nil {
506					t.Fatal(err)
507				}
508				pw.Close()
509			}()
510			var gotSize int64
511
512			// Check CRC
513			d := xxhash.New()
514			if true {
515				gotSize, err = io.Copy(d, dec2)
516			} else {
517				fout, err := os.Create(file + ".got")
518				if err != nil {
519					t.Fatal(err)
520				}
521				gotSize, err = io.Copy(io.MultiWriter(fout, d), dec2)
522				if err != nil {
523					t.Fatal(err)
524				}
525			}
526			if wantSize != gotSize {
527				t.Errorf("want size (%d) != got size (%d)", wantSize, gotSize)
528			}
529			if err != nil {
530				t.Fatal(err)
531			}
532			if gotCRC := d.Sum(nil); len(wantCRC) > 0 && !bytes.Equal(gotCRC, wantCRC) {
533				t.Errorf("crc mismatch %#v (want) != %#v (got)", wantCRC, gotCRC)
534			} else if len(wantCRC) != 8 {
535				t.Logf("Unable to verify CRC: %#v", gotCRC)
536			} else {
537				t.Logf("CRC Verified: %#v", gotCRC)
538			}
539			t.Log("Encoder len", wantSize)
540			mbpersec := (float64(wantSize) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
541			t.Logf("Encoded+Decoded %d bytes with %.2f MB/s", wantSize, mbpersec)
542		})
543	}
544}
545
// writerWrapper exposes only the Write method of the wrapped writer. Any
// other interface the wrapped value implements (such as io.ReaderFrom) is
// hidden from callers that inspect the writer's type.
type writerWrapper struct {
	w io.Writer
}

// Write passes data straight through to the wrapped writer and returns its
// result unchanged.
func (ww writerWrapper) Write(data []byte) (int, error) {
	written, err := ww.w.Write(data)
	return written, err
}
553
554// test roundtrip using plain io.Writer interface.
555func testEncoderRoundtripWriter(t *testing.T, file string, wantCRC []byte) {
556	f, err := os.Open(file)
557	if err != nil {
558		if os.IsNotExist(err) {
559			t.Skip("No input file:", file)
560			return
561		}
562		t.Fatal(err)
563	}
564	defer f.Close()
565	input := io.Reader(f)
566	if strings.HasSuffix(file, ".zst") {
567		dec, err := NewReader(f)
568		if err != nil {
569			t.Fatal(err)
570		}
571		input = dec
572		defer dec.Close()
573	}
574
575	pr, pw := io.Pipe()
576	dec2, err := NewReader(pr)
577	if err != nil {
578		t.Fatal(err)
579	}
580	defer dec2.Close()
581
582	enc, err := NewWriter(pw, WithEncoderCRC(true))
583	if err != nil {
584		t.Fatal(err)
585	}
586	defer enc.Close()
587	encW := writerWrapper{w: enc}
588	var wantSize int64
589	start := time.Now()
590	go func() {
591		n, err := io.CopyBuffer(encW, input, make([]byte, 1337))
592		if err != nil {
593			t.Fatal(err)
594		}
595		wantSize = n
596		err = enc.Close()
597		if err != nil {
598			t.Fatal(err)
599		}
600		pw.Close()
601	}()
602	var gotSize int64
603
604	// Check CRC
605	d := xxhash.New()
606	if true {
607		gotSize, err = io.Copy(d, dec2)
608	} else {
609		fout, err := os.Create(file + ".got")
610		if err != nil {
611			t.Fatal(err)
612		}
613		gotSize, err = io.Copy(io.MultiWriter(fout, d), dec2)
614		if err != nil {
615			t.Fatal(err)
616		}
617	}
618	if wantSize != gotSize {
619		t.Errorf("want size (%d) != got size (%d)", wantSize, gotSize)
620	}
621	if err != nil {
622		t.Fatal(err)
623	}
624	if gotCRC := d.Sum(nil); len(wantCRC) > 0 && !bytes.Equal(gotCRC, wantCRC) {
625		t.Errorf("crc mismatch %#v (want) != %#v (got)", wantCRC, gotCRC)
626	} else if len(wantCRC) != 8 {
627		t.Logf("Unable to verify CRC: %#v", gotCRC)
628	} else {
629		t.Logf("CRC Verified: %#v", gotCRC)
630	}
631	t.Log("Fast Encoder len", wantSize)
632	mbpersec := (float64(wantSize) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
633	t.Logf("Encoded+Decoded %d bytes with %.2f MB/s", wantSize, mbpersec)
634}
635
636func TestEncoder_EncodeAllSilesia(t *testing.T) {
637	if testing.Short() {
638		t.SkipNow()
639	}
640	in, err := ioutil.ReadFile("testdata/silesia.tar")
641	if err != nil {
642		if os.IsNotExist(err) {
643			t.Skip("Missing testdata/silesia.tar")
644			return
645		}
646		t.Fatal(err)
647	}
648
649	var e Encoder
650	start := time.Now()
651	dst := e.EncodeAll(in, nil)
652	t.Log("Fast Encoder len", len(in), "-> zstd len", len(dst))
653	mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
654	t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
655
656	dec, err := NewReader(nil, WithDecoderMaxMemory(220<<20))
657	if err != nil {
658		t.Fatal(err)
659	}
660	defer dec.Close()
661	decoded, err := dec.DecodeAll(dst, nil)
662	if err != nil {
663		t.Error(err, len(decoded))
664	}
665	if !bytes.Equal(decoded, in) {
666		ioutil.WriteFile("testdata/"+t.Name()+"-silesia.tar.got", decoded, os.ModePerm)
667		t.Fatal("Decoded does not match")
668	}
669	t.Log("Encoded content matched")
670}
671
672func TestEncoderReadFrom(t *testing.T) {
673	buffer := bytes.NewBuffer(nil)
674	encoder, err := NewWriter(buffer)
675	if err != nil {
676		t.Fatal(err)
677	}
678	if _, err := encoder.ReadFrom(strings.NewReader("0")); err != nil {
679		t.Fatal(err)
680	}
681	if err := encoder.Close(); err != nil {
682		t.Fatal(err)
683	}
684
685	dec, _ := NewReader(nil)
686	toDec := buffer.Bytes()
687	toDec = append(toDec, toDec...)
688	decoded, err := dec.DecodeAll(toDec, nil)
689	if err != nil {
690		t.Fatal(err)
691	}
692
693	if !bytes.Equal([]byte("00"), decoded) {
694		t.Logf("encoded: % x\n", buffer.Bytes())
695		t.Fatalf("output mismatch, got %s", string(decoded))
696	}
697	dec.Close()
698}
699
700func TestInterleavedWriteReadFrom(t *testing.T) {
701	var encoded bytes.Buffer
702
703	enc, err := NewWriter(&encoded)
704	if err != nil {
705		t.Fatal(err)
706	}
707
708	if _, err := enc.Write([]byte("write1")); err != nil {
709		t.Fatal(err)
710	}
711	if _, err := enc.Write([]byte("write2")); err != nil {
712		t.Fatal(err)
713	}
714	if _, err := enc.ReadFrom(strings.NewReader("readfrom1")); err != nil {
715		t.Fatal(err)
716	}
717	if _, err := enc.Write([]byte("write3")); err != nil {
718		t.Fatal(err)
719	}
720
721	if err := enc.Close(); err != nil {
722		t.Fatal(err)
723	}
724
725	dec, err := NewReader(&encoded)
726	if err != nil {
727		t.Fatal(err)
728	}
729	defer dec.Close()
730
731	gotb, err := ioutil.ReadAll(dec)
732	if err != nil {
733		t.Fatal(err)
734	}
735	got := string(gotb)
736
737	if want := "write1write2readfrom1write3"; got != want {
738		t.Errorf("got decoded %q, want %q", got, want)
739	}
740}
741
742func TestEncoder_EncodeAllEmpty(t *testing.T) {
743	if testing.Short() {
744		t.SkipNow()
745	}
746	var in []byte
747
748	e, err := NewWriter(nil, WithZeroFrames(true))
749	if err != nil {
750		t.Fatal(err)
751	}
752	defer e.Close()
753	dst := e.EncodeAll(in, nil)
754	if len(dst) == 0 {
755		t.Fatal("Requested zero frame, but got nothing.")
756	}
757	t.Log("Block Encoder len", len(in), "-> zstd len", len(dst), dst)
758
759	dec, err := NewReader(nil, WithDecoderMaxMemory(220<<20))
760	if err != nil {
761		t.Fatal(err)
762	}
763	defer dec.Close()
764	decoded, err := dec.DecodeAll(dst, nil)
765	if err != nil {
766		t.Error(err, len(decoded))
767	}
768	if !bytes.Equal(decoded, in) {
769		t.Fatal("Decoded does not match")
770	}
771
772	// Test buffer writer.
773	var buf bytes.Buffer
774	e.Reset(&buf)
775	err = e.Close()
776	if err != nil {
777		t.Fatal(err)
778	}
779	dst = buf.Bytes()
780	if len(dst) == 0 {
781		t.Fatal("Requested zero frame, but got nothing.")
782	}
783	t.Log("Buffer Encoder len", len(in), "-> zstd len", len(dst))
784
785	decoded, err = dec.DecodeAll(dst, nil)
786	if err != nil {
787		t.Error(err, len(decoded))
788	}
789	if !bytes.Equal(decoded, in) {
790		t.Fatal("Decoded does not match")
791	}
792
793	t.Log("Encoded content matched")
794}
795
796func TestEncoder_EncodeAllEnwik9(t *testing.T) {
797	if false || testing.Short() {
798		t.SkipNow()
799	}
800	file := "testdata/enwik9.zst"
801	f, err := os.Open(file)
802	if err != nil {
803		if os.IsNotExist(err) {
804			t.Skip("To run extended tests, download http://mattmahoney.net/dc/enwik9.zip unzip it \n" +
805				"compress it with 'zstd -15 -T0 enwik9' and place it in " + file)
806		}
807	}
808	dec, err := NewReader(f)
809	if err != nil {
810		t.Fatal(err)
811	}
812	defer dec.Close()
813	in, err := ioutil.ReadAll(dec)
814	if err != nil {
815		t.Fatal(err)
816	}
817
818	start := time.Now()
819	var e Encoder
820	dst := e.EncodeAll(in, nil)
821	t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
822	mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
823	t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
824	decoded, err := dec.DecodeAll(dst, nil)
825	if err != nil {
826		t.Error(err, len(decoded))
827	}
828	if !bytes.Equal(decoded, in) {
829		ioutil.WriteFile("testdata/"+t.Name()+"-enwik9.got", decoded, os.ModePerm)
830		t.Fatal("Decoded does not match")
831	}
832	t.Log("Encoded content matched")
833}
834
835func BenchmarkEncoder_EncodeAllXML(b *testing.B) {
836	f, err := os.Open("testdata/xml.zst")
837	if err != nil {
838		b.Fatal(err)
839	}
840	dec, err := NewReader(f)
841	if err != nil {
842		b.Fatal(err)
843	}
844	in, err := ioutil.ReadAll(dec)
845	if err != nil {
846		b.Fatal(err)
847	}
848	dec.Close()
849
850	enc := Encoder{}
851	dst := enc.EncodeAll(in, nil)
852	wantSize := len(dst)
853	b.Log("Output size:", len(dst))
854	b.ResetTimer()
855	b.ReportAllocs()
856	b.SetBytes(int64(len(in)))
857	for i := 0; i < b.N; i++ {
858		dst := enc.EncodeAll(in, dst[:0])
859		if len(dst) != wantSize {
860			b.Fatal(len(dst), "!=", wantSize)
861		}
862	}
863}
864
865func BenchmarkEncoder_EncodeAllSimple(b *testing.B) {
866	f, err := os.Open("testdata/z000028")
867	if err != nil {
868		b.Fatal(err)
869	}
870	in, err := ioutil.ReadAll(f)
871	if err != nil {
872		b.Fatal(err)
873	}
874
875	for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
876		b.Run(level.String(), func(b *testing.B) {
877			enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level))
878			if err != nil {
879				b.Fatal(err)
880			}
881			defer enc.Close()
882			dst := enc.EncodeAll(in, nil)
883			wantSize := len(dst)
884			b.ResetTimer()
885			b.ReportAllocs()
886			b.SetBytes(int64(len(in)))
887			for i := 0; i < b.N; i++ {
888				dst := enc.EncodeAll(in, dst[:0])
889				if len(dst) != wantSize {
890					b.Fatal(len(dst), "!=", wantSize)
891				}
892			}
893		})
894	}
895}
896
897func BenchmarkEncoder_EncodeAllSimple4K(b *testing.B) {
898	f, err := os.Open("testdata/z000028")
899	if err != nil {
900		b.Fatal(err)
901	}
902	in, err := ioutil.ReadAll(f)
903	if err != nil {
904		b.Fatal(err)
905	}
906	in = in[:4096]
907
908	for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
909		b.Run(level.String(), func(b *testing.B) {
910			enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level))
911			if err != nil {
912				b.Fatal(err)
913			}
914			defer enc.Close()
915			dst := enc.EncodeAll(in, nil)
916			wantSize := len(dst)
917			b.ResetTimer()
918			b.ReportAllocs()
919			b.SetBytes(int64(len(in)))
920			for i := 0; i < b.N; i++ {
921				dst := enc.EncodeAll(in, dst[:0])
922				if len(dst) != wantSize {
923					b.Fatal(len(dst), "!=", wantSize)
924				}
925			}
926		})
927	}
928}
929
930func BenchmarkEncoder_EncodeAllHTML(b *testing.B) {
931	f, err := os.Open("../testdata/html.txt")
932	if err != nil {
933		b.Fatal(err)
934	}
935	in, err := ioutil.ReadAll(f)
936	if err != nil {
937		b.Fatal(err)
938	}
939
940	enc := Encoder{}
941	dst := enc.EncodeAll(in, nil)
942	wantSize := len(dst)
943	b.ResetTimer()
944	b.ReportAllocs()
945	b.SetBytes(int64(len(in)))
946	for i := 0; i < b.N; i++ {
947		dst := enc.EncodeAll(in, dst[:0])
948		if len(dst) != wantSize {
949			b.Fatal(len(dst), "!=", wantSize)
950		}
951	}
952}
953
954func BenchmarkEncoder_EncodeAllTwain(b *testing.B) {
955	f, err := os.Open("../testdata/Mark.Twain-Tom.Sawyer.txt")
956	if err != nil {
957		b.Fatal(err)
958	}
959	in, err := ioutil.ReadAll(f)
960	if err != nil {
961		b.Fatal(err)
962	}
963
964	enc := Encoder{}
965	dst := enc.EncodeAll(in, nil)
966	wantSize := len(dst)
967	b.ResetTimer()
968	b.ReportAllocs()
969	b.SetBytes(int64(len(in)))
970	for i := 0; i < b.N; i++ {
971		dst := enc.EncodeAll(in, dst[:0])
972		if len(dst) != wantSize {
973			b.Fatal(len(dst), "!=", wantSize)
974		}
975	}
976}
977
978func BenchmarkEncoder_EncodeAllPi(b *testing.B) {
979	f, err := os.Open("../testdata/pi.txt")
980	if err != nil {
981		b.Fatal(err)
982	}
983	in, err := ioutil.ReadAll(f)
984	if err != nil {
985		b.Fatal(err)
986	}
987
988	enc := Encoder{}
989	dst := enc.EncodeAll(in, nil)
990	wantSize := len(dst)
991	b.ResetTimer()
992	b.ReportAllocs()
993	b.SetBytes(int64(len(in)))
994	for i := 0; i < b.N; i++ {
995		dst := enc.EncodeAll(in, dst[:0])
996		if len(dst) != wantSize {
997			b.Fatal(len(dst), "!=", wantSize)
998		}
999	}
1000}
1001
1002func BenchmarkRandom4KEncodeAllFastest(b *testing.B) {
1003	rng := rand.New(rand.NewSource(1))
1004	data := make([]byte, 4<<10)
1005	for i := range data {
1006		data[i] = uint8(rng.Intn(256))
1007	}
1008	enc, _ := NewWriter(nil, WithEncoderLevel(SpeedFastest), WithEncoderConcurrency(1))
1009	defer enc.Close()
1010	dst := enc.EncodeAll(data, nil)
1011	wantSize := len(dst)
1012	b.ResetTimer()
1013	b.ReportAllocs()
1014	b.SetBytes(int64(len(data)))
1015	for i := 0; i < b.N; i++ {
1016		dst := enc.EncodeAll(data, dst[:0])
1017		if len(dst) != wantSize {
1018			b.Fatal(len(dst), "!=", wantSize)
1019		}
1020	}
1021}
1022
1023func BenchmarkRandom10MBEncodeAllFastest(b *testing.B) {
1024	rng := rand.New(rand.NewSource(1))
1025	data := make([]byte, 10<<20)
1026	rng.Read(data)
1027	enc, _ := NewWriter(nil, WithEncoderLevel(SpeedFastest), WithEncoderConcurrency(1))
1028	defer enc.Close()
1029	dst := enc.EncodeAll(data, nil)
1030	wantSize := len(dst)
1031	b.ResetTimer()
1032	b.ReportAllocs()
1033	b.SetBytes(int64(len(data)))
1034	for i := 0; i < b.N; i++ {
1035		dst := enc.EncodeAll(data, dst[:0])
1036		if len(dst) != wantSize {
1037			b.Fatal(len(dst), "!=", wantSize)
1038		}
1039	}
1040}
1041
1042func BenchmarkRandom4KEncodeAllDefault(b *testing.B) {
1043	rng := rand.New(rand.NewSource(1))
1044	data := make([]byte, 4<<10)
1045	rng.Read(data)
1046	enc, _ := NewWriter(nil, WithEncoderLevel(SpeedDefault), WithEncoderConcurrency(1))
1047	defer enc.Close()
1048	dst := enc.EncodeAll(data, nil)
1049	wantSize := len(dst)
1050	b.ResetTimer()
1051	b.ReportAllocs()
1052	b.SetBytes(int64(len(data)))
1053	for i := 0; i < b.N; i++ {
1054		dst := enc.EncodeAll(data, dst[:0])
1055		if len(dst) != wantSize {
1056			b.Fatal(len(dst), "!=", wantSize)
1057		}
1058	}
1059}
1060
1061func BenchmarkRandomEncodeAllDefault(b *testing.B) {
1062	rng := rand.New(rand.NewSource(1))
1063	data := make([]byte, 10<<20)
1064	rng.Read(data)
1065	enc, _ := NewWriter(nil, WithEncoderLevel(SpeedDefault), WithEncoderConcurrency(1))
1066	defer enc.Close()
1067	dst := enc.EncodeAll(data, nil)
1068	wantSize := len(dst)
1069	b.ResetTimer()
1070	b.ReportAllocs()
1071	b.SetBytes(int64(len(data)))
1072	for i := 0; i < b.N; i++ {
1073		dst := enc.EncodeAll(data, dst[:0])
1074		if len(dst) != wantSize {
1075			b.Fatal(len(dst), "!=", wantSize)
1076		}
1077	}
1078}
1079
1080func BenchmarkRandom10MBEncoderFastest(b *testing.B) {
1081	rng := rand.New(rand.NewSource(1))
1082	data := make([]byte, 10<<20)
1083	rng.Read(data)
1084	wantSize := int64(len(data))
1085	enc, _ := NewWriter(ioutil.Discard, WithEncoderLevel(SpeedFastest))
1086	defer enc.Close()
1087	n, err := io.Copy(enc, bytes.NewBuffer(data))
1088	if err != nil {
1089		b.Fatal(err)
1090	}
1091	if n != wantSize {
1092		b.Fatal(n, "!=", wantSize)
1093	}
1094	b.ResetTimer()
1095	b.ReportAllocs()
1096	b.SetBytes(wantSize)
1097	for i := 0; i < b.N; i++ {
1098		enc.Reset(ioutil.Discard)
1099		n, err := io.Copy(enc, bytes.NewBuffer(data))
1100		if err != nil {
1101			b.Fatal(err)
1102		}
1103		if n != wantSize {
1104			b.Fatal(n, "!=", wantSize)
1105		}
1106	}
1107}
1108
1109func BenchmarkRandomEncoderDefault(b *testing.B) {
1110	rng := rand.New(rand.NewSource(1))
1111	data := make([]byte, 10<<20)
1112	rng.Read(data)
1113	wantSize := int64(len(data))
1114	enc, _ := NewWriter(ioutil.Discard, WithEncoderLevel(SpeedDefault))
1115	defer enc.Close()
1116	n, err := io.Copy(enc, bytes.NewBuffer(data))
1117	if err != nil {
1118		b.Fatal(err)
1119	}
1120	if n != wantSize {
1121		b.Fatal(n, "!=", wantSize)
1122	}
1123	b.ResetTimer()
1124	b.ReportAllocs()
1125	b.SetBytes(wantSize)
1126	for i := 0; i < b.N; i++ {
1127		enc.Reset(ioutil.Discard)
1128		n, err := io.Copy(enc, bytes.NewBuffer(data))
1129		if err != nil {
1130			b.Fatal(err)
1131		}
1132		if n != wantSize {
1133			b.Fatal(n, "!=", wantSize)
1134		}
1135	}
1136}
1137