1// Copyright 2016 Google Inc. All Rights Reserved.
2//
3// Distributed under MIT license.
4// See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
5
6package cbrotli
7
8import (
9	"bytes"
10	"fmt"
11	"io"
12	"io/ioutil"
13	"math"
14	"math/rand"
15	"testing"
16	"time"
17)
18
19func checkCompressedData(compressedData, wantOriginalData []byte) error {
20	uncompressed, err := Decode(compressedData)
21	if err != nil {
22		return fmt.Errorf("brotli decompress failed: %v", err)
23	}
24	if !bytes.Equal(uncompressed, wantOriginalData) {
25		if len(wantOriginalData) != len(uncompressed) {
26			return fmt.Errorf(""+
27				"Data doesn't uncompress to the original value.\n"+
28				"Length of original: %v\n"+
29				"Length of uncompressed: %v",
30				len(wantOriginalData), len(uncompressed))
31		}
32		for i := range wantOriginalData {
33			if wantOriginalData[i] != uncompressed[i] {
34				return fmt.Errorf(""+
35					"Data doesn't uncompress to the original value.\n"+
36					"Original at %v is %v\n"+
37					"Uncompressed at %v is %v",
38					i, wantOriginalData[i], i, uncompressed[i])
39			}
40		}
41	}
42	return nil
43}
44
45func TestEncoderNoWrite(t *testing.T) {
46	out := bytes.Buffer{}
47	e := NewWriter(&out, WriterOptions{Quality: 5})
48	if err := e.Close(); err != nil {
49		t.Errorf("Close()=%v, want nil", err)
50	}
51	// Check Write after close.
52	if _, err := e.Write([]byte("hi")); err == nil {
53		t.Errorf("No error after Close() + Write()")
54	}
55}
56
57func TestEncoderEmptyWrite(t *testing.T) {
58	out := bytes.Buffer{}
59	e := NewWriter(&out, WriterOptions{Quality: 5})
60	n, err := e.Write([]byte(""))
61	if n != 0 || err != nil {
62		t.Errorf("Write()=%v,%v, want 0, nil", n, err)
63	}
64	if err := e.Close(); err != nil {
65		t.Errorf("Close()=%v, want nil", err)
66	}
67}
68
69func TestWriter(t *testing.T) {
70	// Test basic encoder usage.
71	input := []byte("<html><body><H1>Hello world</H1></body></html>")
72	out := bytes.Buffer{}
73	e := NewWriter(&out, WriterOptions{Quality: 1})
74	in := bytes.NewReader([]byte(input))
75	n, err := io.Copy(e, in)
76	if err != nil {
77		t.Errorf("Copy Error: %v", err)
78	}
79	if int(n) != len(input) {
80		t.Errorf("Copy() n=%v, want %v", n, len(input))
81	}
82	if err := e.Close(); err != nil {
83		t.Errorf("Close Error after copied %d bytes: %v", n, err)
84	}
85	if err := checkCompressedData(out.Bytes(), input); err != nil {
86		t.Error(err)
87	}
88}
89
90func TestEncoderStreams(t *testing.T) {
91	// Test that output is streamed.
92	// Adjust window size to ensure the encoder outputs at least enough bytes
93	// to fill the window.
94	const lgWin = 16
95	windowSize := int(math.Pow(2, lgWin))
96	input := make([]byte, 8*windowSize)
97	rand.Read(input)
98	out := bytes.Buffer{}
99	e := NewWriter(&out, WriterOptions{Quality: 11, LGWin: lgWin})
100	halfInput := input[:len(input)/2]
101	in := bytes.NewReader(halfInput)
102
103	n, err := io.Copy(e, in)
104	if err != nil {
105		t.Errorf("Copy Error: %v", err)
106	}
107
108	// We've fed more data than the sliding window size. Check that some
109	// compressed data has been output.
110	if out.Len() == 0 {
111		t.Errorf("Output length is 0 after %d bytes written", n)
112	}
113	if err := e.Close(); err != nil {
114		t.Errorf("Close Error after copied %d bytes: %v", n, err)
115	}
116	if err := checkCompressedData(out.Bytes(), halfInput); err != nil {
117		t.Error(err)
118	}
119}
120
121func TestEncoderLargeInput(t *testing.T) {
122	input := make([]byte, 1000000)
123	rand.Read(input)
124	out := bytes.Buffer{}
125	e := NewWriter(&out, WriterOptions{Quality: 5})
126	in := bytes.NewReader(input)
127
128	n, err := io.Copy(e, in)
129	if err != nil {
130		t.Errorf("Copy Error: %v", err)
131	}
132	if int(n) != len(input) {
133		t.Errorf("Copy() n=%v, want %v", n, len(input))
134	}
135	if err := e.Close(); err != nil {
136		t.Errorf("Close Error after copied %d bytes: %v", n, err)
137	}
138	if err := checkCompressedData(out.Bytes(), input); err != nil {
139		t.Error(err)
140	}
141}
142
143func TestEncoderFlush(t *testing.T) {
144	input := make([]byte, 1000)
145	rand.Read(input)
146	out := bytes.Buffer{}
147	e := NewWriter(&out, WriterOptions{Quality: 5})
148	in := bytes.NewReader(input)
149	_, err := io.Copy(e, in)
150	if err != nil {
151		t.Fatalf("Copy Error: %v", err)
152	}
153	if err := e.Flush(); err != nil {
154		t.Fatalf("Flush(): %v", err)
155	}
156	if out.Len() == 0 {
157		t.Fatalf("0 bytes written after Flush()")
158	}
159	decompressed := make([]byte, 1000)
160	reader := NewReader(bytes.NewReader(out.Bytes()))
161	n, err := reader.Read(decompressed)
162	if n != len(decompressed) || err != nil {
163		t.Errorf("Expected <%v, nil>, but <%v, %v>", len(decompressed), n, err)
164	}
165	reader.Close()
166	if !bytes.Equal(decompressed, input) {
167		t.Errorf(""+
168			"Decompress after flush: %v\n"+
169			"%q\n"+
170			"want:\n%q",
171			err, decompressed, input)
172	}
173	if err := e.Close(); err != nil {
174		t.Errorf("Close(): %v", err)
175	}
176}
177
// readerWithTimeout wraps an io.ReadCloser so that each Read fails with a
// "read timed out" error if the wrapped Read has not returned within 5
// seconds. This lets streaming tests fail instead of hanging forever.
type readerWithTimeout struct {
	io.ReadCloser
}

// Read runs the wrapped Read on a separate goroutine and waits at most
// 5 seconds for it to return.
//
// On timeout the goroutine is abandoned and may still write into p after
// Read has returned, so p should not be relied on after a timeout error.
func (r readerWithTimeout) Read(p []byte) (int, error) {
	type result struct {
		n   int
		err error
	}
	// Buffered so the reading goroutine can always deliver its result and
	// exit, even after Read has stopped waiting for it.
	ch := make(chan result, 1)
	go func() {
		n, err := r.ReadCloser.Read(p)
		ch <- result{n, err}
	}()
	select {
	case res := <-ch:
		return res.n, res.err
	case <-time.After(5 * time.Second):
		return 0, fmt.Errorf("read timed out")
	}
}
199
200func TestDecoderStreaming(t *testing.T) {
201	pr, pw := io.Pipe()
202	writer := NewWriter(pw, WriterOptions{Quality: 5, LGWin: 20})
203	reader := readerWithTimeout{NewReader(pr)}
204	defer func() {
205		if err := reader.Close(); err != nil {
206			t.Errorf("reader.Close: %v", err)
207		}
208		go ioutil.ReadAll(pr) // swallow the "EOF" token from writer.Close
209		if err := writer.Close(); err != nil {
210			t.Errorf("writer.Close: %v", err)
211		}
212	}()
213
214	ch := make(chan []byte)
215	errch := make(chan error)
216	go func() {
217		for {
218			segment, ok := <-ch
219			if !ok {
220				return
221			}
222			if n, err := writer.Write(segment); err != nil || n != len(segment) {
223				errch <- fmt.Errorf("write=%v,%v, want %v,%v", n, err, len(segment), nil)
224				return
225			}
226			if err := writer.Flush(); err != nil {
227				errch <- fmt.Errorf("flush: %v", err)
228				return
229			}
230		}
231	}()
232	defer close(ch)
233
234	segments := [...][]byte{
235		[]byte("first"),
236		[]byte("second"),
237		[]byte("third"),
238	}
239	for k, segment := range segments {
240		t.Run(fmt.Sprintf("Segment%d", k), func(t *testing.T) {
241			select {
242			case ch <- segment:
243			case err := <-errch:
244				t.Fatalf("write: %v", err)
245			case <-time.After(5 * time.Second):
246				t.Fatalf("timed out")
247			}
248			wantLen := len(segment)
249			got := make([]byte, wantLen)
250			if n, err := reader.Read(got); err != nil || n != wantLen || !bytes.Equal(got, segment) {
251				t.Fatalf("read[%d]=%q,%v,%v, want %q,%v,%v", k, got, n, err, segment, wantLen, nil)
252			}
253		})
254	}
255}
256
257func TestReader(t *testing.T) {
258	content := bytes.Repeat([]byte("hello world!"), 10000)
259	encoded, _ := Encode(content, WriterOptions{Quality: 5})
260	r := NewReader(bytes.NewReader(encoded))
261	var decodedOutput bytes.Buffer
262	n, err := io.Copy(&decodedOutput, r)
263	if err != nil {
264		t.Fatalf("Copy(): n=%v, err=%v", n, err)
265	}
266	if err := r.Close(); err != nil {
267		t.Errorf("Close(): %v", err)
268	}
269	if got := decodedOutput.Bytes(); !bytes.Equal(got, content) {
270		t.Errorf(""+
271			"Reader output:\n"+
272			"%q\n"+
273			"want:\n"+
274			"<%d bytes>",
275			got, len(content))
276	}
277}
278
279func TestDecode(t *testing.T) {
280	content := bytes.Repeat([]byte("hello world!"), 10000)
281	encoded, _ := Encode(content, WriterOptions{Quality: 5})
282	decoded, err := Decode(encoded)
283	if err != nil {
284		t.Errorf("Decode: %v", err)
285	}
286	if !bytes.Equal(decoded, content) {
287		t.Errorf(""+
288			"Decode content:\n"+
289			"%q\n"+
290			"want:\n"+
291			"<%d bytes>",
292			decoded, len(content))
293	}
294}
295
296func TestDecodeFuzz(t *testing.T) {
297	// Test that the decoder terminates with corrupted input.
298	content := bytes.Repeat([]byte("hello world!"), 100)
299	src := rand.NewSource(0)
300	encoded, err := Encode(content, WriterOptions{Quality: 5})
301	if err != nil {
302		t.Fatalf("Encode(<%d bytes>, _) = _, %s", len(content), err)
303	}
304	if len(encoded) == 0 {
305		t.Fatalf("Encode(<%d bytes>, _) produced empty output", len(content))
306	}
307	for i := 0; i < 100; i++ {
308		enc := append([]byte{}, encoded...)
309		for j := 0; j < 5; j++ {
310			enc[int(src.Int63())%len(enc)] = byte(src.Int63() % 256)
311		}
312		Decode(enc)
313	}
314}
315
316func TestDecodeTrailingData(t *testing.T) {
317	content := bytes.Repeat([]byte("hello world!"), 100)
318	encoded, _ := Encode(content, WriterOptions{Quality: 5})
319	_, err := Decode(append(encoded, 0))
320	if err == nil {
321		t.Errorf("Expected 'excessive input' error")
322	}
323}
324
325func TestEncodeDecode(t *testing.T) {
326	for _, test := range []struct {
327		data    []byte
328		repeats int
329	}{
330		{nil, 0},
331		{[]byte("A"), 1},
332		{[]byte("<html><body><H1>Hello world</H1></body></html>"), 10},
333		{[]byte("<html><body><H1>Hello world</H1></body></html>"), 1000},
334	} {
335		t.Logf("case %q x %d", test.data, test.repeats)
336		input := bytes.Repeat(test.data, test.repeats)
337		encoded, err := Encode(input, WriterOptions{Quality: 5})
338		if err != nil {
339			t.Errorf("Encode: %v", err)
340		}
341		// Inputs are compressible, but may be too small to compress.
342		if maxSize := len(input)/2 + 20; len(encoded) >= maxSize {
343			t.Errorf(""+
344				"Encode returned %d bytes, want <%d\n"+
345				"Encoded=%q",
346				len(encoded), maxSize, encoded)
347		}
348		decoded, err := Decode(encoded)
349		if err != nil {
350			t.Errorf("Decode: %v", err)
351		}
352		if !bytes.Equal(decoded, input) {
353			var want string
354			if len(input) > 320 {
355				want = fmt.Sprintf("<%d bytes>", len(input))
356			} else {
357				want = fmt.Sprintf("%q", input)
358			}
359			t.Errorf(""+
360				"Decode content:\n"+
361				"%q\n"+
362				"want:\n"+
363				"%s",
364				decoded, want)
365		}
366	}
367}
368