1// Copyright 2014-2017 Ulrich Kunitz. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package xz
6
7import (
8	"bytes"
9	"crypto/sha256"
10	"errors"
11	"fmt"
12	"hash"
13	"hash/crc32"
14	"io"
15
16	"github.com/ulikunitz/xz/lzma"
17)
18
// allZeros reports whether p consists solely of zero bytes. An empty
// or nil slice is considered to be all zeros.
func allZeros(p []byte) bool {
	for i := 0; i < len(p); i++ {
		if p[i] != 0 {
			return false
		}
	}
	return true
}
28
// padLen computes the number of padding bytes that must follow n bytes
// to reach the next multiple of four. It returns zero if n is already
// a multiple of four.
func padLen(n int64) int {
	rem := int(n % 4)
	if rem <= 0 {
		return rem
	}
	return 4 - rem
}
38
39/*** Header ***/
40
// headerMagic stores the six magic bytes that start every xz file
// header.
var headerMagic = []byte{0xfd, '7', 'z', 'X', 'Z', 0x00}

// HeaderLen provides the length of the xz file header in bytes.
const HeaderLen = 12

// Constants for the checksum methods supported by xz. One of these
// values is stored in the flags byte of the file header and of the
// file footer.
const (
	CRC32  byte = 0x1
	CRC64       = 0x4
	SHA256      = 0xa
)

// errInvalidFlags indicates that flags are invalid.
var errInvalidFlags = errors.New("xz: invalid flags")
56
57// verifyFlags returns the error errInvalidFlags if the value is
58// invalid.
59func verifyFlags(flags byte) error {
60	switch flags {
61	case CRC32, CRC64, SHA256:
62		return nil
63	default:
64		return errInvalidFlags
65	}
66}
67
// flagstrings maps the supported checksum flag values to their
// human-readable names. It is the lookup table used by flagString.
var flagstrings = map[byte]string{
	CRC32:  "CRC-32",
	CRC64:  "CRC-64",
	SHA256: "SHA-256",
}
74
75// flagString returns the string representation for the given flags.
76func flagString(flags byte) string {
77	s, ok := flagstrings[flags]
78	if !ok {
79		return "invalid"
80	}
81	return s
82}
83
84// newHashFunc returns a function that creates hash instances for the
85// hash method encoded in flags.
86func newHashFunc(flags byte) (newHash func() hash.Hash, err error) {
87	switch flags {
88	case CRC32:
89		newHash = newCRC32
90	case CRC64:
91		newHash = newCRC64
92	case SHA256:
93		newHash = sha256.New
94	default:
95		err = errInvalidFlags
96	}
97	return
98}
99
// header provides the actual content of the xz file header: the flags.
// The flags field holds the checksum method, one of CRC32, CRC64 or
// SHA256.
type header struct {
	flags byte
}

// errHeaderMagic is returned by header.UnmarshalBinary if the data does
// not start with the xz header magic bytes.
var errHeaderMagic = errors.New("xz: invalid header magic bytes")
107
108// ValidHeader checks whether data is a correct xz file header. The
109// length of data must be HeaderLen.
110func ValidHeader(data []byte) bool {
111	var h header
112	err := h.UnmarshalBinary(data)
113	return err == nil
114}
115
116// String returns a string representation of the flags.
117func (h header) String() string {
118	return flagString(h.flags)
119}
120
121// UnmarshalBinary reads header from the provided data slice.
122func (h *header) UnmarshalBinary(data []byte) error {
123	// header length
124	if len(data) != HeaderLen {
125		return errors.New("xz: wrong file header length")
126	}
127
128	// magic header
129	if !bytes.Equal(headerMagic, data[:6]) {
130		return errHeaderMagic
131	}
132
133	// checksum
134	crc := crc32.NewIEEE()
135	crc.Write(data[6:8])
136	if uint32LE(data[8:]) != crc.Sum32() {
137		return errors.New("xz: invalid checksum for file header")
138	}
139
140	// stream flags
141	if data[6] != 0 {
142		return errInvalidFlags
143	}
144	flags := data[7]
145	if err := verifyFlags(flags); err != nil {
146		return err
147	}
148
149	h.flags = flags
150	return nil
151}
152
153// MarshalBinary generates the xz file header.
154func (h *header) MarshalBinary() (data []byte, err error) {
155	if err = verifyFlags(h.flags); err != nil {
156		return nil, err
157	}
158
159	data = make([]byte, 12)
160	copy(data, headerMagic)
161	data[7] = h.flags
162
163	crc := crc32.NewIEEE()
164	crc.Write(data[6:8])
165	putUint32LE(data[8:], crc.Sum32())
166
167	return data, nil
168}
169
170/*** Footer ***/
171
// footerLen defines the length of the footer in bytes.
const footerLen = 12

// footerMagic contains the footer magic bytes. They are the last two
// bytes of the footer.
var footerMagic = []byte{'Y', 'Z'}

// footer represents the content of the xz file footer. The indexSize
// field is the size of the index in bytes; the footer stores it as
// indexSize/4 - 1 (the backward size). The flags field holds the
// checksum method.
type footer struct {
	indexSize int64
	flags     byte
}
183
184// String prints a string representation of the footer structure.
185func (f footer) String() string {
186	return fmt.Sprintf("%s index size %d", flagString(f.flags), f.indexSize)
187}
188
// Minimum and maximum for the size of the index (backward size) in
// bytes. The footer stores the size as indexSize/4 - 1 in a 32-bit
// field, which limits the index size to 2^32 * 4 bytes.
const (
	minIndexSize = 4
	maxIndexSize = (1 << 32) * 4
)
194
195// MarshalBinary converts footer values into an xz file footer. Note
196// that the footer value is checked for correctness.
197func (f *footer) MarshalBinary() (data []byte, err error) {
198	if err = verifyFlags(f.flags); err != nil {
199		return nil, err
200	}
201	if !(minIndexSize <= f.indexSize && f.indexSize <= maxIndexSize) {
202		return nil, errors.New("xz: index size out of range")
203	}
204	if f.indexSize%4 != 0 {
205		return nil, errors.New(
206			"xz: index size not aligned to four bytes")
207	}
208
209	data = make([]byte, footerLen)
210
211	// backward size (index size)
212	s := (f.indexSize / 4) - 1
213	putUint32LE(data[4:], uint32(s))
214	// flags
215	data[9] = f.flags
216	// footer magic
217	copy(data[10:], footerMagic)
218
219	// CRC-32
220	crc := crc32.NewIEEE()
221	crc.Write(data[4:10])
222	putUint32LE(data, crc.Sum32())
223
224	return data, nil
225}
226
227// UnmarshalBinary sets the footer value by unmarshalling an xz file
228// footer.
229func (f *footer) UnmarshalBinary(data []byte) error {
230	if len(data) != footerLen {
231		return errors.New("xz: wrong footer length")
232	}
233
234	// magic bytes
235	if !bytes.Equal(data[10:], footerMagic) {
236		return errors.New("xz: footer magic invalid")
237	}
238
239	// CRC-32
240	crc := crc32.NewIEEE()
241	crc.Write(data[4:10])
242	if uint32LE(data) != crc.Sum32() {
243		return errors.New("xz: footer checksum error")
244	}
245
246	var g footer
247	// backward size (index size)
248	g.indexSize = (int64(uint32LE(data[4:])) + 1) * 4
249
250	// flags
251	if data[8] != 0 {
252		return errInvalidFlags
253	}
254	g.flags = data[9]
255	if err := verifyFlags(g.flags); err != nil {
256		return err
257	}
258
259	*f = g
260	return nil
261}
262
263/*** Block Header ***/
264
// blockHeader represents the content of an xz block header. A negative
// compressedSize or uncompressedSize means that the respective field is
// not present in the header. The filters field lists the filters of the
// block in the order they are stored in the header.
type blockHeader struct {
	compressedSize   int64
	uncompressedSize int64
	filters          []filter
}
271
272// String converts the block header into a string.
273func (h blockHeader) String() string {
274	var buf bytes.Buffer
275	first := true
276	if h.compressedSize >= 0 {
277		fmt.Fprintf(&buf, "compressed size %d", h.compressedSize)
278		first = false
279	}
280	if h.uncompressedSize >= 0 {
281		if !first {
282			buf.WriteString(" ")
283		}
284		fmt.Fprintf(&buf, "uncompressed size %d", h.uncompressedSize)
285		first = false
286	}
287	for _, f := range h.filters {
288		if !first {
289			buf.WriteString(" ")
290		}
291		fmt.Fprintf(&buf, "filter %s", f)
292		first = false
293	}
294	return buf.String()
295}
296
// Masks for the block flags, the second byte of the block header.
const (
	filterCountMask         = 0x03 // number of filters minus one
	compressedSizePresent   = 0x40 // compressed size field is present
	uncompressedSizePresent = 0x80 // uncompressed size field is present
	reservedBlockFlags      = 0x3C // bits that must be zero
)

// errIndexIndicator signals that an index indicator (0x00) has been found
// instead of an expected block header indicator. It is returned by
// readBlockHeader and blockHeader.UnmarshalBinary if the first byte of
// the header is zero.
var errIndexIndicator = errors.New("xz: found index indicator")
308
309// readBlockHeader reads the block header.
310func readBlockHeader(r io.Reader) (h *blockHeader, n int, err error) {
311	var buf bytes.Buffer
312	buf.Grow(20)
313
314	// block header size
315	z, err := io.CopyN(&buf, r, 1)
316	n = int(z)
317	if err != nil {
318		return nil, n, err
319	}
320	s := buf.Bytes()[0]
321	if s == 0 {
322		return nil, n, errIndexIndicator
323	}
324
325	// read complete header
326	headerLen := (int(s) + 1) * 4
327	buf.Grow(headerLen - 1)
328	z, err = io.CopyN(&buf, r, int64(headerLen-1))
329	n += int(z)
330	if err != nil {
331		return nil, n, err
332	}
333
334	// unmarshal block header
335	h = new(blockHeader)
336	if err = h.UnmarshalBinary(buf.Bytes()); err != nil {
337		return nil, n, err
338	}
339
340	return h, n, nil
341}
342
343// readSizeInBlockHeader reads the uncompressed or compressed size
344// fields in the block header. The present value informs the function
345// whether the respective field is actually present in the header.
346func readSizeInBlockHeader(r io.ByteReader, present bool) (n int64, err error) {
347	if !present {
348		return -1, nil
349	}
350	x, _, err := readUvarint(r)
351	if err != nil {
352		return 0, err
353	}
354	if x >= 1<<63 {
355		return 0, errors.New("xz: size overflow in block header")
356	}
357	return int64(x), nil
358}
359
360// UnmarshalBinary unmarshals the block header.
361func (h *blockHeader) UnmarshalBinary(data []byte) error {
362	// Check header length
363	s := data[0]
364	if data[0] == 0 {
365		return errIndexIndicator
366	}
367	headerLen := (int(s) + 1) * 4
368	if len(data) != headerLen {
369		return fmt.Errorf("xz: data length %d; want %d", len(data),
370			headerLen)
371	}
372	n := headerLen - 4
373
374	// Check CRC-32
375	crc := crc32.NewIEEE()
376	crc.Write(data[:n])
377	if crc.Sum32() != uint32LE(data[n:]) {
378		return errors.New("xz: checksum error for block header")
379	}
380
381	// Block header flags
382	flags := data[1]
383	if flags&reservedBlockFlags != 0 {
384		return errors.New("xz: reserved block header flags set")
385	}
386
387	r := bytes.NewReader(data[2:n])
388
389	// Compressed size
390	var err error
391	h.compressedSize, err = readSizeInBlockHeader(
392		r, flags&compressedSizePresent != 0)
393	if err != nil {
394		return err
395	}
396
397	// Uncompressed size
398	h.uncompressedSize, err = readSizeInBlockHeader(
399		r, flags&uncompressedSizePresent != 0)
400	if err != nil {
401		return err
402	}
403
404	h.filters, err = readFilters(r, int(flags&filterCountMask)+1)
405	if err != nil {
406		return err
407	}
408
409	// Check padding
410	// Since headerLen is a multiple of 4 we don't need to check
411	// alignment.
412	k := r.Len()
413	// The standard spec says that the padding should have not more
414	// than 3 bytes. However we found paddings of 4 or 5 in the
415	// wild. See https://github.com/ulikunitz/xz/pull/11 and
416	// https://github.com/ulikunitz/xz/issues/15
417	//
418	// The only reasonable approach seems to be to ignore the
419	// padding size. We still check that all padding bytes are zero.
420	if !allZeros(data[n-k : n]) {
421		return errPadding
422	}
423	return nil
424}
425
426// MarshalBinary marshals the binary header.
427func (h *blockHeader) MarshalBinary() (data []byte, err error) {
428	if !(minFilters <= len(h.filters) && len(h.filters) <= maxFilters) {
429		return nil, errors.New("xz: filter count wrong")
430	}
431	for i, f := range h.filters {
432		if i < len(h.filters)-1 {
433			if f.id() == lzmaFilterID {
434				return nil, errors.New(
435					"xz: LZMA2 filter is not the last")
436			}
437		} else {
438			// last filter
439			if f.id() != lzmaFilterID {
440				return nil, errors.New("xz: " +
441					"last filter must be the LZMA2 filter")
442			}
443		}
444	}
445
446	var buf bytes.Buffer
447	// header size must set at the end
448	buf.WriteByte(0)
449
450	// flags
451	flags := byte(len(h.filters) - 1)
452	if h.compressedSize >= 0 {
453		flags |= compressedSizePresent
454	}
455	if h.uncompressedSize >= 0 {
456		flags |= uncompressedSizePresent
457	}
458	buf.WriteByte(flags)
459
460	p := make([]byte, 10)
461	if h.compressedSize >= 0 {
462		k := putUvarint(p, uint64(h.compressedSize))
463		buf.Write(p[:k])
464	}
465	if h.uncompressedSize >= 0 {
466		k := putUvarint(p, uint64(h.uncompressedSize))
467		buf.Write(p[:k])
468	}
469
470	for _, f := range h.filters {
471		fp, err := f.MarshalBinary()
472		if err != nil {
473			return nil, err
474		}
475		buf.Write(fp)
476	}
477
478	// padding
479	for i := padLen(int64(buf.Len())); i > 0; i-- {
480		buf.WriteByte(0)
481	}
482
483	// crc place holder
484	buf.Write(p[:4])
485
486	data = buf.Bytes()
487	if len(data)%4 != 0 {
488		panic("data length not aligned")
489	}
490	s := len(data)/4 - 1
491	if !(1 < s && s <= 255) {
492		panic("wrong block header size")
493	}
494	data[0] = byte(s)
495
496	crc := crc32.NewIEEE()
497	crc.Write(data[:len(data)-4])
498	putUint32LE(data[len(data)-4:], crc.Sum32())
499
500	return data, nil
501}
502
// Constants used for marshalling and unmarshalling filters in the xz
// block header. minFilters and maxFilters bound the number of filters
// accepted by blockHeader.MarshalBinary. Filter ids greater than or
// equal to minReservedID are reported as reserved by readFilter.
const (
	minFilters    = 1
	maxFilters    = 4
	minReservedID = 1 << 62
)
510
// filter represents a filter in the block header.
type filter interface {
	// id returns the filter id as stored in the block header.
	id() uint64
	// UnmarshalBinary and MarshalBinary convert between the filter
	// and its binary encoding in the block header.
	UnmarshalBinary(data []byte) error
	MarshalBinary() (data []byte, err error)
	// NOTE(review): reader and writeCloser presumably wrap the given
	// stream with the decoding and encoding side of the filter;
	// confirm against the implementations.
	reader(r io.Reader, c *ReaderConfig) (fr io.Reader, err error)
	writeCloser(w io.WriteCloser, c *WriterConfig) (fw io.WriteCloser, err error)
	// filter must be last filter
	last() bool
}
521
522// readFilter reads a block filter from the block header. At this point
523// in time only the LZMA2 filter is supported.
524func readFilter(r io.Reader) (f filter, err error) {
525	br := lzma.ByteReader(r)
526
527	// index
528	id, _, err := readUvarint(br)
529	if err != nil {
530		return nil, err
531	}
532
533	var data []byte
534	switch id {
535	case lzmaFilterID:
536		data = make([]byte, lzmaFilterLen)
537		data[0] = lzmaFilterID
538		if _, err = io.ReadFull(r, data[1:]); err != nil {
539			return nil, err
540		}
541		f = new(lzmaFilter)
542	default:
543		if id >= minReservedID {
544			return nil, errors.New(
545				"xz: reserved filter id in block stream header")
546		}
547		return nil, errors.New("xz: invalid filter id")
548	}
549	if err = f.UnmarshalBinary(data); err != nil {
550		return nil, err
551	}
552	return f, err
553}
554
555// readFilters reads count filters. At this point in time only the count
556// 1 is supported.
557func readFilters(r io.Reader, count int) (filters []filter, err error) {
558	if count != 1 {
559		return nil, errors.New("xz: unsupported filter count")
560	}
561	f, err := readFilter(r)
562	if err != nil {
563		return nil, err
564	}
565	return []filter{f}, err
566}
567
568// writeFilters writes the filters.
569func writeFilters(w io.Writer, filters []filter) (n int, err error) {
570	for _, f := range filters {
571		p, err := f.MarshalBinary()
572		if err != nil {
573			return n, err
574		}
575		k, err := w.Write(p)
576		n += k
577		if err != nil {
578			return n, err
579		}
580	}
581	return n, nil
582}
583
584/*** Index ***/
585
// record describes a block in the xz file index. Both fields are
// encoded as uvarints, unpaddedSize first and uncompressedSize second.
type record struct {
	unpaddedSize     int64
	uncompressedSize int64
}
591
592// readRecord reads an index record.
593func readRecord(r io.ByteReader) (rec record, n int, err error) {
594	u, k, err := readUvarint(r)
595	n += k
596	if err != nil {
597		return rec, n, err
598	}
599	rec.unpaddedSize = int64(u)
600	if rec.unpaddedSize < 0 {
601		return rec, n, errors.New("xz: unpadded size negative")
602	}
603
604	u, k, err = readUvarint(r)
605	n += k
606	if err != nil {
607		return rec, n, err
608	}
609	rec.uncompressedSize = int64(u)
610	if rec.uncompressedSize < 0 {
611		return rec, n, errors.New("xz: uncompressed size negative")
612	}
613
614	return rec, n, nil
615}
616
617// MarshalBinary converts an index record in its binary encoding.
618func (rec *record) MarshalBinary() (data []byte, err error) {
619	// maximum length of a uvarint is 10
620	p := make([]byte, 20)
621	n := putUvarint(p, uint64(rec.unpaddedSize))
622	n += putUvarint(p[n:], uint64(rec.uncompressedSize))
623	return p[:n], nil
624}
625
626// writeIndex writes the index, a sequence of records.
627func writeIndex(w io.Writer, index []record) (n int64, err error) {
628	crc := crc32.NewIEEE()
629	mw := io.MultiWriter(w, crc)
630
631	// index indicator
632	k, err := mw.Write([]byte{0})
633	n += int64(k)
634	if err != nil {
635		return n, err
636	}
637
638	// number of records
639	p := make([]byte, 10)
640	k = putUvarint(p, uint64(len(index)))
641	k, err = mw.Write(p[:k])
642	n += int64(k)
643	if err != nil {
644		return n, err
645	}
646
647	// list of records
648	for _, rec := range index {
649		p, err := rec.MarshalBinary()
650		if err != nil {
651			return n, err
652		}
653		k, err = mw.Write(p)
654		n += int64(k)
655		if err != nil {
656			return n, err
657		}
658	}
659
660	// index padding
661	k, err = mw.Write(make([]byte, padLen(int64(n))))
662	n += int64(k)
663	if err != nil {
664		return n, err
665	}
666
667	// crc32 checksum
668	putUint32LE(p, crc.Sum32())
669	k, err = w.Write(p[:4])
670	n += int64(k)
671
672	return n, err
673}
674
675// readIndexBody reads the index from the reader. It assumes that the
676// index indicator has already been read.
677func readIndexBody(r io.Reader) (records []record, n int64, err error) {
678	crc := crc32.NewIEEE()
679	// index indicator
680	crc.Write([]byte{0})
681
682	br := lzma.ByteReader(io.TeeReader(r, crc))
683
684	// number of records
685	u, k, err := readUvarint(br)
686	n += int64(k)
687	if err != nil {
688		return nil, n, err
689	}
690	recLen := int(u)
691	if recLen < 0 || uint64(recLen) != u {
692		return nil, n, errors.New("xz: record number overflow")
693	}
694
695	// list of records
696	records = make([]record, recLen)
697	for i := range records {
698		records[i], k, err = readRecord(br)
699		n += int64(k)
700		if err != nil {
701			return nil, n, err
702		}
703	}
704
705	p := make([]byte, padLen(int64(n+1)), 4)
706	k, err = io.ReadFull(br.(io.Reader), p)
707	n += int64(k)
708	if err != nil {
709		return nil, n, err
710	}
711	if !allZeros(p) {
712		return nil, n, errors.New("xz: non-zero byte in index padding")
713	}
714
715	// crc32
716	s := crc.Sum32()
717	p = p[:4]
718	k, err = io.ReadFull(br.(io.Reader), p)
719	n += int64(k)
720	if err != nil {
721		return records, n, err
722	}
723	if uint32LE(p) != s {
724		return nil, n, errors.New("xz: wrong checksum for index")
725	}
726
727	return records, n, nil
728}
729