1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"encoding/binary"
9	"fmt"
10	"io"
11	"math"
12	"math/bits"
13)
14
// frameHeader describes the values encoded into a zstd frame header.
type frameHeader struct {
	// ContentSize is written to the Frame_Content_Size field. Without
	// SingleSegment, values below 256 are not stored at all.
	ContentSize uint64

	// WindowSize is encoded as the Window_Descriptor. It is ignored when
	// SingleSegment is set, because the descriptor is then omitted.
	WindowSize uint32

	// SingleSegment sets bit 5 of the Frame_Header_Descriptor.
	SingleSegment bool

	// Checksum sets bit 2 (content checksum flag) of the descriptor.
	Checksum bool

	// DictID is the dictionary ID. Zero means no Dictionary_ID is written;
	// otherwise the smallest of 1, 2 or 4 bytes that fits is used.
	DictID uint32
}

// maxHeaderSize is 14 bytes. This equals the largest set of optional header
// fields: descriptor (1) + window descriptor (1) + dictionary ID (4) +
// content size (8), excluding the 4-byte magic.
// NOTE(review): confirm at the use sites that this is the intended meaning.
const maxHeaderSize = 14
24
25func (f frameHeader) appendTo(dst []byte) ([]byte, error) {
26	dst = append(dst, frameMagic...)
27	var fhd uint8
28	if f.Checksum {
29		fhd |= 1 << 2
30	}
31	if f.SingleSegment {
32		fhd |= 1 << 5
33	}
34
35	var dictIDContent []byte
36	if f.DictID > 0 {
37		var tmp [4]byte
38		if f.DictID < 256 {
39			fhd |= 1
40			tmp[0] = uint8(f.DictID)
41			dictIDContent = tmp[:1]
42		} else if f.DictID < 1<<16 {
43			fhd |= 2
44			binary.LittleEndian.PutUint16(tmp[:2], uint16(f.DictID))
45			dictIDContent = tmp[:2]
46		} else {
47			fhd |= 3
48			binary.LittleEndian.PutUint32(tmp[:4], f.DictID)
49			dictIDContent = tmp[:4]
50		}
51	}
52	var fcs uint8
53	if f.ContentSize >= 256 {
54		fcs++
55	}
56	if f.ContentSize >= 65536+256 {
57		fcs++
58	}
59	if f.ContentSize >= 0xffffffff {
60		fcs++
61	}
62
63	fhd |= fcs << 6
64
65	dst = append(dst, fhd)
66	if !f.SingleSegment {
67		const winLogMin = 10
68		windowLog := (bits.Len32(f.WindowSize-1) - winLogMin) << 3
69		dst = append(dst, uint8(windowLog))
70	}
71	if f.DictID > 0 {
72		dst = append(dst, dictIDContent...)
73	}
74	switch fcs {
75	case 0:
76		if f.SingleSegment {
77			dst = append(dst, uint8(f.ContentSize))
78		}
79		// Unless SingleSegment is set, framessizes < 256 are nto stored.
80	case 1:
81		f.ContentSize -= 256
82		dst = append(dst, uint8(f.ContentSize), uint8(f.ContentSize>>8))
83	case 2:
84		dst = append(dst, uint8(f.ContentSize), uint8(f.ContentSize>>8), uint8(f.ContentSize>>16), uint8(f.ContentSize>>24))
85	case 3:
86		dst = append(dst, uint8(f.ContentSize), uint8(f.ContentSize>>8), uint8(f.ContentSize>>16), uint8(f.ContentSize>>24),
87			uint8(f.ContentSize>>32), uint8(f.ContentSize>>40), uint8(f.ContentSize>>48), uint8(f.ContentSize>>56))
88	default:
89		panic("invalid fcs")
90	}
91	return dst, nil
92}
93
// skippableFrameHeader is the size of a skippable frame header: a 4-byte
// magic number followed by a 4-byte frame size.
const skippableFrameHeader = 4 + 4

// calcSkippableFrame returns the total size of a skippable frame that, when
// added after written bytes, makes the total output a multiple of
// wantMultiple.
//
// The result is 0 when written is already a multiple of wantMultiple.
// Otherwise it is the smallest size that restores divisibility and is at
// least skippableFrameHeader, since a skippable frame cannot be smaller than
// its own header. The result can therefore equal skippableFrameHeader
// exactly, which is a frame with an empty payload.
//
// The function panics if written < 0 or wantMultiple <= 0.
func calcSkippableFrame(written, wantMultiple int64) int {
	if wantMultiple <= 0 {
		panic("wantMultiple <= 0")
	}
	if written < 0 {
		panic("written < 0")
	}
	leftOver := written % wantMultiple
	if leftOver == 0 {
		return 0
	}
	toAdd := wantMultiple - leftOver
	// The gap may be too small to hold a frame header. Padding by extra whole
	// multiples keeps the output aligned while making room for the header.
	for toAdd < skippableFrameHeader {
		toAdd += wantMultiple
	}
	return int(toAdd)
}
117
118// skippableFrame will add a skippable frame with a total size of bytes.
119// total should be >= skippableFrameHeader and < math.MaxUint32.
120func skippableFrame(dst []byte, total int, r io.Reader) ([]byte, error) {
121	if total == 0 {
122		return dst, nil
123	}
124	if total < skippableFrameHeader {
125		return dst, fmt.Errorf("requested skippable frame (%d) < 8", total)
126	}
127	if int64(total) > math.MaxUint32 {
128		return dst, fmt.Errorf("requested skippable frame (%d) > max uint32", total)
129	}
130	dst = append(dst, 0x50, 0x2a, 0x4d, 0x18)
131	f := uint32(total - skippableFrameHeader)
132	dst = append(dst, uint8(f), uint8(f>>8), uint8(f>>16), uint8(f>>24))
133	start := len(dst)
134	dst = append(dst, make([]byte, f)...)
135	_, err := io.ReadFull(r, dst[start:])
136	return dst, err
137}
138