1// Copyright 2014-2021 Ulrich Kunitz. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package xz supports the compression and decompression of xz files. It
6// supports version 1.0.4 of the specification without the non-LZMA2
7// filters. See http://tukaani.org/xz/xz-file-format-1.0.4.txt
8package xz
9
10import (
11	"bytes"
12	"errors"
13	"fmt"
14	"hash"
15	"io"
16
17	"github.com/ulikunitz/xz/internal/xlog"
18	"github.com/ulikunitz/xz/lzma"
19)
20
// ReaderConfig holds the parameters for the xz reader.
type ReaderConfig struct {
	// DictCap is passed on as the dictionary capacity of the LZMA2
	// reader configuration when the parameters are verified.
	DictCap int

	// SingleStream requests the reader to assume that the underlying
	// input contains only a single xz stream.
	SingleStream bool
}
28
29// Verify checks the reader parameters for Validity. Zero values will be
30// replaced by default values.
31func (c *ReaderConfig) Verify() error {
32	if c == nil {
33		return errors.New("xz: reader parameters are nil")
34	}
35	lc := lzma.Reader2Config{DictCap: c.DictCap}
36	if err := lc.Verify(); err != nil {
37		return err
38	}
39	return nil
40}
41
42// Reader supports the reading of one or multiple xz streams.
43type Reader struct {
44	ReaderConfig
45
46	xz io.Reader
47	sr *streamReader
48}
49
50// streamReader decodes a single xz stream
51type streamReader struct {
52	ReaderConfig
53
54	xz      io.Reader
55	br      *blockReader
56	newHash func() hash.Hash
57	h       header
58	index   []record
59}
60
61// NewReader creates a new xz reader using the default parameters.
62// The function reads and checks the header of the first XZ stream. The
63// reader will process multiple streams including padding.
64func NewReader(xz io.Reader) (r *Reader, err error) {
65	return ReaderConfig{}.NewReader(xz)
66}
67
68// NewReader creates an xz stream reader. The created reader will be
69// able to process multiple streams and padding unless a SingleStream
70// has been set in the reader configuration c.
71func (c ReaderConfig) NewReader(xz io.Reader) (r *Reader, err error) {
72	if err = c.Verify(); err != nil {
73		return nil, err
74	}
75	r = &Reader{
76		ReaderConfig: c,
77		xz:           xz,
78	}
79	if r.sr, err = c.newStreamReader(xz); err != nil {
80		if err == io.EOF {
81			err = io.ErrUnexpectedEOF
82		}
83		return nil, err
84	}
85	return r, nil
86}
87
88var errUnexpectedData = errors.New("xz: unexpected data after stream")
89
90// Read reads uncompressed data from the stream.
91func (r *Reader) Read(p []byte) (n int, err error) {
92	for n < len(p) {
93		if r.sr == nil {
94			if r.SingleStream {
95				data := make([]byte, 1)
96				_, err = io.ReadFull(r.xz, data)
97				if err != io.EOF {
98					return n, errUnexpectedData
99				}
100				return n, io.EOF
101			}
102			for {
103				r.sr, err = r.ReaderConfig.newStreamReader(r.xz)
104				if err != errPadding {
105					break
106				}
107			}
108			if err != nil {
109				return n, err
110			}
111		}
112		k, err := r.sr.Read(p[n:])
113		n += k
114		if err != nil {
115			if err == io.EOF {
116				r.sr = nil
117				continue
118			}
119			return n, err
120		}
121	}
122	return n, nil
123}
124
125var errPadding = errors.New("xz: padding (4 zero bytes) encountered")
126
127// newStreamReader creates a new xz stream reader using the given configuration
128// parameters. NewReader reads and checks the header of the xz stream.
129func (c ReaderConfig) newStreamReader(xz io.Reader) (r *streamReader, err error) {
130	if err = c.Verify(); err != nil {
131		return nil, err
132	}
133	data := make([]byte, HeaderLen)
134	if _, err := io.ReadFull(xz, data[:4]); err != nil {
135		return nil, err
136	}
137	if bytes.Equal(data[:4], []byte{0, 0, 0, 0}) {
138		return nil, errPadding
139	}
140	if _, err = io.ReadFull(xz, data[4:]); err != nil {
141		if err == io.EOF {
142			err = io.ErrUnexpectedEOF
143		}
144		return nil, err
145	}
146	r = &streamReader{
147		ReaderConfig: c,
148		xz:           xz,
149		index:        make([]record, 0, 4),
150	}
151	if err = r.h.UnmarshalBinary(data); err != nil {
152		return nil, err
153	}
154	xlog.Debugf("xz header %s", r.h)
155	if r.newHash, err = newHashFunc(r.h.flags); err != nil {
156		return nil, err
157	}
158	return r, nil
159}
160
161// readTail reads the index body and the xz footer.
162func (r *streamReader) readTail() error {
163	index, n, err := readIndexBody(r.xz, len(r.index))
164	if err != nil {
165		if err == io.EOF {
166			err = io.ErrUnexpectedEOF
167		}
168		return err
169	}
170
171	for i, rec := range r.index {
172		if rec != index[i] {
173			return fmt.Errorf("xz: record %d is %v; want %v",
174				i, rec, index[i])
175		}
176	}
177
178	p := make([]byte, footerLen)
179	if _, err = io.ReadFull(r.xz, p); err != nil {
180		if err == io.EOF {
181			err = io.ErrUnexpectedEOF
182		}
183		return err
184	}
185	var f footer
186	if err = f.UnmarshalBinary(p); err != nil {
187		return err
188	}
189	xlog.Debugf("xz footer %s", f)
190	if f.flags != r.h.flags {
191		return errors.New("xz: footer flags incorrect")
192	}
193	if f.indexSize != int64(n)+1 {
194		return errors.New("xz: index size in footer wrong")
195	}
196	return nil
197}
198
199// Read reads actual data from the xz stream.
200func (r *streamReader) Read(p []byte) (n int, err error) {
201	for n < len(p) {
202		if r.br == nil {
203			bh, hlen, err := readBlockHeader(r.xz)
204			if err != nil {
205				if err == errIndexIndicator {
206					if err = r.readTail(); err != nil {
207						return n, err
208					}
209					return n, io.EOF
210				}
211				return n, err
212			}
213			xlog.Debugf("block %v", *bh)
214			r.br, err = r.ReaderConfig.newBlockReader(r.xz, bh,
215				hlen, r.newHash())
216			if err != nil {
217				return n, err
218			}
219		}
220		k, err := r.br.Read(p[n:])
221		n += k
222		if err != nil {
223			if err == io.EOF {
224				r.index = append(r.index, r.br.record())
225				r.br = nil
226			} else {
227				return n, err
228			}
229		}
230	}
231	return n, nil
232}
233
// countingReader wraps a reader and keeps track of the total number of
// bytes read through it.
type countingReader struct {
	// r is the wrapped reader.
	r io.Reader
	// n is the number of bytes read so far.
	n int64
}

// Read forwards the call to the wrapped reader and adds the number of
// bytes it delivered to the n field. The result of the wrapped reader
// is returned unchanged.
func (lr *countingReader) Read(p []byte) (n int, err error) {
	got, rerr := lr.r.Read(p)
	lr.n += int64(got)
	return got, rerr
}
246
247// blockReader supports the reading of a block.
248type blockReader struct {
249	lxz       countingReader
250	header    *blockHeader
251	headerLen int
252	n         int64
253	hash      hash.Hash
254	r         io.Reader
255}
256
257// newBlockReader creates a new block reader.
258func (c *ReaderConfig) newBlockReader(xz io.Reader, h *blockHeader,
259	hlen int, hash hash.Hash) (br *blockReader, err error) {
260
261	br = &blockReader{
262		lxz:       countingReader{r: xz},
263		header:    h,
264		headerLen: hlen,
265		hash:      hash,
266	}
267
268	fr, err := c.newFilterReader(&br.lxz, h.filters)
269	if err != nil {
270		return nil, err
271	}
272	if br.hash.Size() != 0 {
273		br.r = io.TeeReader(fr, br.hash)
274	} else {
275		br.r = fr
276	}
277
278	return br, nil
279}
280
281// uncompressedSize returns the uncompressed size of the block.
282func (br *blockReader) uncompressedSize() int64 {
283	return br.n
284}
285
286// compressedSize returns the compressed size of the block.
287func (br *blockReader) compressedSize() int64 {
288	return br.lxz.n
289}
290
291// unpaddedSize computes the unpadded size for the block.
292func (br *blockReader) unpaddedSize() int64 {
293	n := int64(br.headerLen)
294	n += br.compressedSize()
295	n += int64(br.hash.Size())
296	return n
297}
298
299// record returns the index record for the current block.
300func (br *blockReader) record() record {
301	return record{br.unpaddedSize(), br.uncompressedSize()}
302}
303
304// Read reads data from the block.
305func (br *blockReader) Read(p []byte) (n int, err error) {
306	n, err = br.r.Read(p)
307	br.n += int64(n)
308
309	u := br.header.uncompressedSize
310	if u >= 0 && br.uncompressedSize() > u {
311		return n, errors.New("xz: wrong uncompressed size for block")
312	}
313	c := br.header.compressedSize
314	if c >= 0 && br.compressedSize() > c {
315		return n, errors.New("xz: wrong compressed size for block")
316	}
317	if err != io.EOF {
318		return n, err
319	}
320	if br.uncompressedSize() < u || br.compressedSize() < c {
321		return n, io.ErrUnexpectedEOF
322	}
323
324	s := br.hash.Size()
325	k := padLen(br.lxz.n)
326	q := make([]byte, k+s, k+2*s)
327	if _, err = io.ReadFull(br.lxz.r, q); err != nil {
328		if err == io.EOF {
329			err = io.ErrUnexpectedEOF
330		}
331		return n, err
332	}
333	if !allZeros(q[:k]) {
334		return n, errors.New("xz: non-zero block padding")
335	}
336	checkSum := q[k:]
337	computedSum := br.hash.Sum(checkSum[s:])
338	if !bytes.Equal(checkSum, computedSum) {
339		return n, errors.New("xz: checksum error for block")
340	}
341	return n, io.EOF
342}
343
344func (c *ReaderConfig) newFilterReader(r io.Reader, f []filter) (fr io.Reader,
345	err error) {
346
347	if err = verifyFilters(f); err != nil {
348		return nil, err
349	}
350
351	fr = r
352	for i := len(f) - 1; i >= 0; i-- {
353		fr, err = f[i].reader(fr, c)
354		if err != nil {
355			return nil, err
356		}
357	}
358	return fr, nil
359}
360