1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package flate
6
7import (
8	"archive/zip"
9	"bytes"
10	"fmt"
11	"io"
12	"io/ioutil"
13	"math"
14	"math/rand"
15	"runtime"
16	"strconv"
17	"strings"
18	"testing"
19)
20
21func TestWriterRegression(t *testing.T) {
22	data, err := ioutil.ReadFile("testdata/regression.zip")
23	if err != nil {
24		t.Fatal(err)
25	}
26	for level := HuffmanOnly; level <= BestCompression; level++ {
27		t.Run(fmt.Sprint("level_", level), func(t *testing.T) {
28			zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
29			if err != nil {
30				t.Fatal(err)
31			}
32
33			for _, tt := range zr.File {
34				if !strings.HasSuffix(t.Name(), "") {
35					continue
36				}
37
38				t.Run(tt.Name, func(t *testing.T) {
39					r, err := tt.Open()
40					if err != nil {
41						t.Error(err)
42						return
43					}
44					in, err := ioutil.ReadAll(r)
45					if err != nil {
46						t.Error(err)
47					}
48					msg := "level " + strconv.Itoa(level) + ":"
49					buf := new(bytes.Buffer)
50					fw, err := NewWriter(buf, level)
51					if err != nil {
52						t.Fatal(msg + err.Error())
53					}
54					n, err := fw.Write(in)
55					if n != len(in) {
56						t.Fatal(msg + "short write")
57					}
58					if err != nil {
59						t.Fatal(msg + err.Error())
60					}
61					err = fw.Close()
62					if err != nil {
63						t.Fatal(msg + err.Error())
64					}
65					fr1 := NewReader(buf)
66					data2, err := ioutil.ReadAll(fr1)
67					if err != nil {
68						t.Fatal(msg + err.Error())
69					}
70					if bytes.Compare(in, data2) != 0 {
71						t.Fatal(msg + "not equal")
72					}
73					// Do it again...
74					msg = "level " + strconv.Itoa(level) + " (reset):"
75					buf.Reset()
76					fw.Reset(buf)
77					n, err = fw.Write(in)
78					if n != len(in) {
79						t.Fatal(msg + "short write")
80					}
81					if err != nil {
82						t.Fatal(msg + err.Error())
83					}
84					err = fw.Close()
85					if err != nil {
86						t.Fatal(msg + err.Error())
87					}
88					fr1 = NewReader(buf)
89					data2, err = ioutil.ReadAll(fr1)
90					if err != nil {
91						t.Fatal(msg + err.Error())
92					}
93					if bytes.Compare(in, data2) != 0 {
94						t.Fatal(msg + "not equal")
95					}
96				})
97			}
98		})
99	}
100}
101
102func benchmarkEncoder(b *testing.B, testfile, level, n int) {
103	b.SetBytes(int64(n))
104	buf0, err := ioutil.ReadFile(testfiles[testfile])
105	if err != nil {
106		b.Fatal(err)
107	}
108	if len(buf0) == 0 {
109		b.Fatalf("test file %q has no data", testfiles[testfile])
110	}
111	buf1 := make([]byte, n)
112	for i := 0; i < n; i += len(buf0) {
113		if len(buf0) > n-i {
114			buf0 = buf0[:n-i]
115		}
116		copy(buf1[i:], buf0)
117	}
118	buf0 = nil
119	runtime.GC()
120	w, err := NewWriter(ioutil.Discard, level)
121	b.ResetTimer()
122	b.ReportAllocs()
123	for i := 0; i < b.N; i++ {
124		w.Reset(ioutil.Discard)
125		_, err = w.Write(buf1)
126		if err != nil {
127			b.Fatal(err)
128		}
129		err = w.Close()
130		if err != nil {
131			b.Fatal(err)
132		}
133	}
134}
135
// Encoder benchmarks. Each one forwards a test file identifier (digits or
// twain), a level identifier (constant, speed, default_ or compress) and an
// input size in bytes (1e4, 1e5 or 1e6) to benchmarkEncoder. The SL variants
// call benchmarkStatelessEncoder instead, which takes no level.

func BenchmarkEncodeDigitsConstant1e4(b *testing.B) { benchmarkEncoder(b, digits, constant, 1e4) }
func BenchmarkEncodeDigitsConstant1e5(b *testing.B) { benchmarkEncoder(b, digits, constant, 1e5) }
func BenchmarkEncodeDigitsConstant1e6(b *testing.B) { benchmarkEncoder(b, digits, constant, 1e6) }
func BenchmarkEncodeDigitsSpeed1e4(b *testing.B)    { benchmarkEncoder(b, digits, speed, 1e4) }
func BenchmarkEncodeDigitsSpeed1e5(b *testing.B)    { benchmarkEncoder(b, digits, speed, 1e5) }
func BenchmarkEncodeDigitsSpeed1e6(b *testing.B)    { benchmarkEncoder(b, digits, speed, 1e6) }
func BenchmarkEncodeDigitsDefault1e4(b *testing.B)  { benchmarkEncoder(b, digits, default_, 1e4) }
func BenchmarkEncodeDigitsDefault1e5(b *testing.B)  { benchmarkEncoder(b, digits, default_, 1e5) }
func BenchmarkEncodeDigitsDefault1e6(b *testing.B)  { benchmarkEncoder(b, digits, default_, 1e6) }
func BenchmarkEncodeDigitsCompress1e4(b *testing.B) { benchmarkEncoder(b, digits, compress, 1e4) }
func BenchmarkEncodeDigitsCompress1e5(b *testing.B) { benchmarkEncoder(b, digits, compress, 1e5) }
func BenchmarkEncodeDigitsCompress1e6(b *testing.B) { benchmarkEncoder(b, digits, compress, 1e6) }
func BenchmarkEncodeDigitsSL1e4(b *testing.B)       { benchmarkStatelessEncoder(b, digits, 1e4) }
func BenchmarkEncodeDigitsSL1e5(b *testing.B)       { benchmarkStatelessEncoder(b, digits, 1e5) }
func BenchmarkEncodeDigitsSL1e6(b *testing.B)       { benchmarkStatelessEncoder(b, digits, 1e6) }
func BenchmarkEncodeTwainConstant1e4(b *testing.B)  { benchmarkEncoder(b, twain, constant, 1e4) }
func BenchmarkEncodeTwainConstant1e5(b *testing.B)  { benchmarkEncoder(b, twain, constant, 1e5) }
func BenchmarkEncodeTwainConstant1e6(b *testing.B)  { benchmarkEncoder(b, twain, constant, 1e6) }
func BenchmarkEncodeTwainSpeed1e4(b *testing.B)     { benchmarkEncoder(b, twain, speed, 1e4) }
func BenchmarkEncodeTwainSpeed1e5(b *testing.B)     { benchmarkEncoder(b, twain, speed, 1e5) }
func BenchmarkEncodeTwainSpeed1e6(b *testing.B)     { benchmarkEncoder(b, twain, speed, 1e6) }
func BenchmarkEncodeTwainDefault1e4(b *testing.B)   { benchmarkEncoder(b, twain, default_, 1e4) }
func BenchmarkEncodeTwainDefault1e5(b *testing.B)   { benchmarkEncoder(b, twain, default_, 1e5) }
func BenchmarkEncodeTwainDefault1e6(b *testing.B)   { benchmarkEncoder(b, twain, default_, 1e6) }
func BenchmarkEncodeTwainCompress1e4(b *testing.B)  { benchmarkEncoder(b, twain, compress, 1e4) }
func BenchmarkEncodeTwainCompress1e5(b *testing.B)  { benchmarkEncoder(b, twain, compress, 1e5) }
func BenchmarkEncodeTwainCompress1e6(b *testing.B)  { benchmarkEncoder(b, twain, compress, 1e6) }
func BenchmarkEncodeTwainSL1e4(b *testing.B)        { benchmarkStatelessEncoder(b, twain, 1e4) }
func BenchmarkEncodeTwainSL1e5(b *testing.B)        { benchmarkStatelessEncoder(b, twain, 1e5) }
func BenchmarkEncodeTwainSL1e6(b *testing.B)        { benchmarkStatelessEncoder(b, twain, 1e6) }
166
167func benchmarkStatelessEncoder(b *testing.B, testfile, n int) {
168	b.SetBytes(int64(n))
169	buf0, err := ioutil.ReadFile(testfiles[testfile])
170	if err != nil {
171		b.Fatal(err)
172	}
173	if len(buf0) == 0 {
174		b.Fatalf("test file %q has no data", testfiles[testfile])
175	}
176	buf1 := make([]byte, n)
177	for i := 0; i < n; i += len(buf0) {
178		if len(buf0) > n-i {
179			buf0 = buf0[:n-i]
180		}
181		copy(buf1[i:], buf0)
182	}
183	buf0 = nil
184	runtime.GC()
185	b.ResetTimer()
186	b.ReportAllocs()
187	for i := 0; i < b.N; i++ {
188		w := NewStatelessWriter(ioutil.Discard)
189		_, err = w.Write(buf1)
190		if err != nil {
191			b.Fatal(err)
192		}
193		err = w.Close()
194		if err != nil {
195			b.Fatal(err)
196		}
197	}
198}
199
// errorWriter is a writer that accepts N writes and fails all writes after
// that with io.ErrClosedPipe.
type errorWriter struct {
	N int // number of writes that will still succeed
}

// Write reports the whole of b as written while e.N is positive, using up
// one of the remaining writes. Once e.N has reached zero (or if it started
// out non-positive) it writes nothing and returns io.ErrClosedPipe.
func (e *errorWriter) Write(b []byte) (int, error) {
	if e.N > 0 {
		e.N--
		return len(b), nil
	}
	return 0, io.ErrClosedPipe
}
212
213// Test if errors from the underlying writer is passed upwards.
214func TestWriteError(t *testing.T) {
215	buf := new(bytes.Buffer)
216	n := 65536
217	if !testing.Short() {
218		n *= 4
219	}
220	for i := 0; i < n; i++ {
221		fmt.Fprintf(buf, "asdasfasf%d%dfghfgujyut%dyutyu\n", i, i, i)
222	}
223	in := buf.Bytes()
224	// We create our own buffer to control number of writes.
225	copyBuf := make([]byte, 128)
226	for l := 0; l < 10; l++ {
227		for fail := 1; fail <= 256; fail *= 2 {
228			// Fail after 'fail' writes
229			ew := &errorWriter{N: fail}
230			w, err := NewWriter(ew, l)
231			if err != nil {
232				t.Fatalf("NewWriter: level %d: %v", l, err)
233			}
234			n, err := copyBuffer(w, bytes.NewBuffer(in), copyBuf)
235			if err == nil {
236				t.Fatalf("Level %d: Expected an error, writer was %#v", l, ew)
237			}
238			n2, err := w.Write([]byte{1, 2, 2, 3, 4, 5})
239			if n2 != 0 {
240				t.Fatal("Level", l, "Expected 0 length write, got", n)
241			}
242			if err == nil {
243				t.Fatal("Level", l, "Expected an error")
244			}
245			err = w.Flush()
246			if err == nil {
247				t.Fatal("Level", l, "Expected an error on flush")
248			}
249			err = w.Close()
250			if err == nil {
251				t.Fatal("Level", l, "Expected an error on close")
252			}
253
254			w.Reset(ioutil.Discard)
255			n2, err = w.Write([]byte{1, 2, 3, 4, 5, 6})
256			if err != nil {
257				t.Fatal("Level", l, "Got unexpected error after reset:", err)
258			}
259			if n2 == 0 {
260				t.Fatal("Level", l, "Got 0 length write, expected > 0")
261			}
262			if testing.Short() {
263				return
264			}
265		}
266	}
267}
268
269// Test if errors from the underlying writer is passed upwards.
270func TestWriter_Reset(t *testing.T) {
271	buf := new(bytes.Buffer)
272	n := 65536
273	if !testing.Short() {
274		n *= 4
275	}
276	for i := 0; i < n; i++ {
277		fmt.Fprintf(buf, "asdasfasf%d%dfghfgujyut%dyutyu\n", i, i, i)
278	}
279	in := buf.Bytes()
280	for l := 0; l < 10; l++ {
281		l := l
282		if testing.Short() && l > 1 {
283			continue
284		}
285		t.Run(fmt.Sprintf("level-%d", l), func(t *testing.T) {
286			t.Parallel()
287			offset := 1
288			if testing.Short() {
289				offset = 256
290			}
291			for ; offset <= 256; offset *= 2 {
292				// Fail after 'fail' writes
293				w, err := NewWriter(ioutil.Discard, l)
294				if err != nil {
295					t.Fatalf("NewWriter: level %d: %v", l, err)
296				}
297				if w.d.fast == nil {
298					t.Skip("Not Fast...")
299					return
300				}
301				for i := 0; i < (bufferReset-len(in)-offset-maxMatchOffset)/maxMatchOffset; i++ {
302					// skip ahead to where we are close to wrap around...
303					w.d.fast.Reset()
304				}
305				w.d.fast.Reset()
306				_, err = w.Write(in)
307				if err != nil {
308					t.Fatal(err)
309				}
310				for i := 0; i < 50; i++ {
311					// skip ahead again... This should wrap around...
312					w.d.fast.Reset()
313				}
314				w.d.fast.Reset()
315
316				_, err = w.Write(in)
317				if err != nil {
318					t.Fatal(err)
319				}
320				for i := 0; i < (math.MaxUint32-bufferReset)/maxMatchOffset; i++ {
321					// skip ahead to where we are close to wrap around...
322					w.d.fast.Reset()
323				}
324
325				_, err = w.Write(in)
326				if err != nil {
327					t.Fatal(err)
328				}
329				err = w.Close()
330				if err != nil {
331					t.Fatal(err)
332				}
333			}
334		})
335	}
336}
337
// Determinism tests. Each one calls testDeterministic with the level given
// as its first argument: 0 through 9, and -2 for the LM2 variant.

func TestDeterministicL1(t *testing.T)  { testDeterministic(1, t) }
func TestDeterministicL2(t *testing.T)  { testDeterministic(2, t) }
func TestDeterministicL3(t *testing.T)  { testDeterministic(3, t) }
func TestDeterministicL4(t *testing.T)  { testDeterministic(4, t) }
func TestDeterministicL5(t *testing.T)  { testDeterministic(5, t) }
func TestDeterministicL6(t *testing.T)  { testDeterministic(6, t) }
func TestDeterministicL7(t *testing.T)  { testDeterministic(7, t) }
func TestDeterministicL8(t *testing.T)  { testDeterministic(8, t) }
func TestDeterministicL9(t *testing.T)  { testDeterministic(9, t) }
func TestDeterministicL0(t *testing.T)  { testDeterministic(0, t) }
func TestDeterministicLM2(t *testing.T) { testDeterministic(-2, t) }
349
350func testDeterministic(i int, t *testing.T) {
351	// Test so much we cross a good number of block boundaries.
352	var length = maxStoreBlockSize*30 + 500
353	if testing.Short() {
354		length /= 10
355	}
356
357	// Create a random, but compressible stream.
358	rng := rand.New(rand.NewSource(1))
359	t1 := make([]byte, length)
360	for i := range t1 {
361		t1[i] = byte(rng.Int63() & 7)
362	}
363
364	// Do our first encode.
365	var b1 bytes.Buffer
366	br := bytes.NewBuffer(t1)
367	w, err := NewWriter(&b1, i)
368	if err != nil {
369		t.Fatal(err)
370	}
371	// Use a very small prime sized buffer.
372	cbuf := make([]byte, 787)
373	_, err = copyBuffer(w, br, cbuf)
374	if err != nil {
375		t.Fatal(err)
376	}
377	w.Close()
378
379	// We choose a different buffer size,
380	// bigger than a maximum block, and also a prime.
381	var b2 bytes.Buffer
382	cbuf = make([]byte, 81761)
383	br2 := bytes.NewBuffer(t1)
384	w2, err := NewWriter(&b2, i)
385	if err != nil {
386		t.Fatal(err)
387	}
388	_, err = copyBuffer(w2, br2, cbuf)
389	if err != nil {
390		t.Fatal(err)
391	}
392	w2.Close()
393
394	b1b := b1.Bytes()
395	b2b := b2.Bytes()
396
397	if !bytes.Equal(b1b, b2b) {
398		t.Errorf("level %d did not produce deterministic result, result mismatch, len(a) = %d, len(b) = %d", i, len(b1b), len(b2b))
399	}
400
401	// Test using io.WriterTo interface.
402	var b3 bytes.Buffer
403	br = bytes.NewBuffer(t1)
404	w, err = NewWriter(&b3, i)
405	if err != nil {
406		t.Fatal(err)
407	}
408	_, err = br.WriteTo(w)
409	if err != nil {
410		t.Fatal(err)
411	}
412	w.Close()
413
414	b3b := b3.Bytes()
415	if !bytes.Equal(b1b, b3b) {
416		t.Errorf("level %d (io.WriterTo) did not produce deterministic result, result mismatch, len(a) = %d, len(b) = %d", i, len(b1b), len(b3b))
417	}
418}
419
// copyBuffer is a copy of io.CopyBuffer, since we want to support older go versions.
// This is modified to never use io.WriterTo or io.ReaderFrom interfaces.
//
// It reads from src into buf and writes each chunk to dst until src reports
// io.EOF or an error occurs. If buf is nil or has zero length, a 32 KiB
// buffer is allocated; a zero-length buffer could otherwise never make
// progress. It returns the number of bytes written and the first error
// encountered; io.EOF from src is not treated as an error. A write that
// accepts fewer bytes than it was given without an error of its own yields
// io.ErrShortWrite.
func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) {
	if len(buf) == 0 {
		buf = make([]byte, 32*1024)
	}
	for {
		nr, er := src.Read(buf)
		// Data returned alongside an error is still written before the
		// read error is looked at.
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
			}
			if ew != nil {
				return written, ew
			}
			if nr != nw {
				return written, io.ErrShortWrite
			}
		}
		if er == io.EOF {
			return written, nil
		}
		if er != nil {
			return written, er
		}
	}
}
452