1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package brotli
6
7import (
8	"bufio"
9	"io"
10
11	"github.com/dsnet/compress/internal/errors"
12)
13
14// The bitReader preserves the property that it will never read more bytes than
15// is necessary. However, this feature dramatically hurts performance because
16// every byte needs to be obtained through a ReadByte method call.
// Furthermore, decoding variable-length codes in ReadSymbol often requires
// multiple passes before the exact bit-length of the code is known.
19//
20// Thus, to improve performance, if the underlying byteReader is a bufio.Reader,
21// then the bitReader will use the Peek and Discard methods to fill the internal
22// bit buffer with as many bits as possible, allowing the TryReadBits and
23// TryReadSymbol methods to often succeed on the first try.
24
// byteReader is the input interface the bitReader consumes: single bytes via
// ReadByte (used when filling the bit buffer without a bufio.Reader) and bulk
// data via Read (used by bitReader.Read once the bit buffer is empty).
type byteReader interface {
	io.ByteReader
	io.Reader
}
29
// bitReader reads LSB-first bit fields and prefix-coded symbols from rd.
//
// Input bytes are appended above the bits already held in bufBits, and bits
// are consumed from the least significant end, so bit 0 of bufBits is always
// the next bit in the stream.
type bitReader struct {
	rd      byteReader
	bufBits uint64 // Buffered bits; bit 0 is the next bit to consume
	numBits uint   // Number of valid bits in bufBits
	offset  int64  // Number of bytes read from the underlying io.Reader

	// These fields are only used if rd is a bufio.Reader. In that mode bytes
	// are copied into bufBits from Peek results and are only removed from
	// the bufio.Reader later, by Discard in FlushOffset.
	bufRd       *bufio.Reader
	bufPeek     []byte // Peeked bytes not yet copied into bufBits
	discardBits int    // Bits pending Discard; -7..0 right after a flush
	fedBits     uint   // numBits as of the last FeedBits or FlushOffset

	// Local copy of decoders to reduce memory allocations. Init preserves it,
	// and readComplexPrefixCode builds its code-length code into it.
	prefix prefixDecoder
}
45
46func (br *bitReader) Init(r io.Reader) {
47	*br = bitReader{prefix: br.prefix}
48	if rr, ok := r.(byteReader); ok {
49		br.rd = rr
50	} else {
51		br.rd = bufio.NewReader(r)
52	}
53	if brd, ok := br.rd.(*bufio.Reader); ok {
54		br.bufRd = brd
55	}
56}
57
58// FlushOffset updates the read offset of the underlying byteReader.
59// If the byteReader is a bufio.Reader, then this calls Discard to update the
60// read offset.
61func (br *bitReader) FlushOffset() int64 {
62	if br.bufRd == nil {
63		return br.offset
64	}
65
66	// Update the number of total bits to discard.
67	br.discardBits += int(br.fedBits - br.numBits)
68	br.fedBits = br.numBits
69
70	// Discard some bytes to update read offset.
71	nd := (br.discardBits + 7) / 8 // Round up to nearest byte
72	nd, _ = br.bufRd.Discard(nd)
73	br.discardBits -= nd * 8 // -7..0
74	br.offset += int64(nd)
75
76	// These are invalid after Discard.
77	br.bufPeek = nil
78	return br.offset
79}
80
// FeedBits ensures that at least nb bits exist in the bit buffer.
// If the underlying byteReader is a bufio.Reader, then this will fill the
// bit buffer with as many bits as possible, relying on Peek and Discard to
// properly advance the read offset. Otherwise, it will use ReadByte to fill the
// buffer with just the right number of bits.
//
// Bits are only ever added in whole bytes, above the bits already buffered.
// If the input ends before nb bits are available, it panics (via
// errors.Panic) with io.ErrUnexpectedEOF. Other read errors are passed to
// errors.Panic unchanged.
//
// NOTE(review): the guarantee only holds for nb <= 57. The bufio path stops
// once more than 56 bits are buffered. For a larger nb, the ReadByte path
// would shift bits past the top of the 64-bit buffer. Confirm no caller
// passes a larger nb.
func (br *bitReader) FeedBits(nb uint) {
	if br.bufRd != nil {
		// Bits consumed since the previous feed still occupy the bufio.Reader;
		// schedule them for discarding.
		br.discardBits += int(br.fedBits - br.numBits)
		for {
			if len(br.bufPeek) == 0 {
				// Peeked data is exhausted: discard consumed bytes and peek
				// anew. discardBits is already up to date, so make
				// FlushOffset see no further consumed bits.
				br.fedBits = br.numBits // Don't discard bits just added
				br.FlushOffset()

				var err error
				cntPeek := 8 // Minimum Peek amount to make progress
				if br.bufRd.Buffered() > cntPeek {
					cntPeek = br.bufRd.Buffered()
				}
				br.bufPeek, err = br.bufRd.Peek(cntPeek)
				// Whole bytes still held in bufBits were not discarded by
				// FlushOffset, so they reappear at the front of the peek.
				br.bufPeek = br.bufPeek[int(br.numBits/8):] // Skip buffered bits
				if len(br.bufPeek) == 0 {
					// No new data: succeed if enough bits are buffered already.
					if br.numBits >= nb {
						break
					}
					if err == io.EOF {
						err = io.ErrUnexpectedEOF
					}
					errors.Panic(err)
				}
			}
			// Copy as many whole bytes as fit into the 64-bit buffer.
			cnt := int(64-br.numBits) / 8
			if cnt > len(br.bufPeek) {
				cnt = len(br.bufPeek)
			}
			for _, c := range br.bufPeek[:cnt] {
				br.bufBits |= uint64(c) << br.numBits
				br.numBits += 8
			}
			br.bufPeek = br.bufPeek[cnt:]
			if br.numBits > 56 {
				break // Not enough room left for another whole byte
			}
		}
		// Remember the fill level so the next FeedBits/FlushOffset can tell
		// how many bits were consumed in between.
		br.fedBits = br.numBits
	} else {
		// Read only as many bytes as needed, so no input is over-read.
		for br.numBits < nb {
			c, err := br.rd.ReadByte()
			if err != nil {
				if err == io.EOF {
					err = io.ErrUnexpectedEOF
				}
				errors.Panic(err)
			}
			br.bufBits |= uint64(c) << br.numBits
			br.numBits += 8
			br.offset++
		}
	}
}
140
141// Read reads up to len(buf) bytes into buf.
142func (br *bitReader) Read(buf []byte) (cnt int, err error) {
143	if br.numBits%8 != 0 {
144		return 0, errorf(errors.Invalid, "non-aligned bit buffer")
145	}
146	if br.numBits > 0 {
147		for cnt = 0; len(buf) > cnt && br.numBits > 0; cnt++ {
148			buf[cnt] = byte(br.bufBits)
149			br.bufBits >>= 8
150			br.numBits -= 8
151		}
152	} else {
153		br.FlushOffset()
154		cnt, err = br.rd.Read(buf)
155		br.offset += int64(cnt)
156	}
157	return cnt, err
158}
159
160// TryReadBits attempts to read nb bits using the contents of the bit buffer
161// alone. It returns the value and whether it succeeded.
162//
163// This method is designed to be inlined for performance reasons.
164func (br *bitReader) TryReadBits(nb uint) (uint, bool) {
165	if br.numBits < nb {
166		return 0, false
167	}
168	val := uint(br.bufBits & uint64(1<<nb-1))
169	br.bufBits >>= nb
170	br.numBits -= nb
171	return val, true
172}
173
174// ReadBits reads nb bits in LSB order from the underlying reader.
175func (br *bitReader) ReadBits(nb uint) uint {
176	br.FeedBits(nb)
177	val := uint(br.bufBits & uint64(1<<nb-1))
178	br.bufBits >>= nb
179	br.numBits -= nb
180	return val
181}
182
183// ReadPads reads 0-7 bits from the bit buffer to achieve byte-alignment.
184func (br *bitReader) ReadPads() uint {
185	nb := br.numBits % 8
186	val := uint(br.bufBits & uint64(1<<nb-1))
187	br.bufBits >>= nb
188	br.numBits -= nb
189	return val
190}
191
192// TryReadSymbol attempts to decode the next symbol using the contents of the
193// bit buffer alone. It returns the decoded symbol and whether it succeeded.
194//
195// This method is designed to be inlined for performance reasons.
196func (br *bitReader) TryReadSymbol(pd *prefixDecoder) (uint, bool) {
197	if br.numBits < uint(pd.minBits) || len(pd.chunks) == 0 {
198		return 0, false
199	}
200	chunk := pd.chunks[uint32(br.bufBits)&pd.chunkMask]
201	nb := uint(chunk & prefixCountMask)
202	if nb > br.numBits || nb > uint(pd.chunkBits) {
203		return 0, false
204	}
205	br.bufBits >>= nb
206	br.numBits -= nb
207	return uint(chunk >> prefixCountBits), true
208}
209
210// ReadSymbol reads the next prefix symbol using the provided prefixDecoder.
211func (br *bitReader) ReadSymbol(pd *prefixDecoder) uint {
212	if len(pd.chunks) == 0 {
213		errors.Panic(errInvalid) // Decode with empty tree
214	}
215
216	nb := uint(pd.minBits)
217	for {
218		br.FeedBits(nb)
219		chunk := pd.chunks[uint32(br.bufBits)&pd.chunkMask]
220		nb = uint(chunk & prefixCountMask)
221		if nb > uint(pd.chunkBits) {
222			linkIdx := chunk >> prefixCountBits
223			chunk = pd.links[linkIdx][uint32(br.bufBits>>pd.chunkBits)&pd.linkMask]
224			nb = uint(chunk & prefixCountMask)
225		}
226		if nb <= br.numBits {
227			br.bufBits >>= nb
228			br.numBits -= nb
229			return uint(chunk >> prefixCountBits)
230		}
231	}
232}
233
234// ReadOffset reads an offset value using the provided rangesCodes indexed by
235// the given symbol.
236func (br *bitReader) ReadOffset(sym uint, rcs []rangeCode) uint {
237	rc := rcs[sym]
238	return uint(rc.base) + br.ReadBits(uint(rc.bits))
239}
240
241// ReadPrefixCode reads the prefix definition from the stream and initializes
242// the provided prefixDecoder. The value maxSyms is the alphabet size of the
243// prefix code being generated. The actual number of representable symbols
244// will be between 1 and maxSyms, inclusively.
245func (br *bitReader) ReadPrefixCode(pd *prefixDecoder, maxSyms uint) {
246	hskip := br.ReadBits(2)
247	if hskip == 1 {
248		br.readSimplePrefixCode(pd, maxSyms)
249	} else {
250		br.readComplexPrefixCode(pd, maxSyms, hskip)
251	}
252}
253
254// readSimplePrefixCode reads the prefix code according to RFC section 3.4.
255func (br *bitReader) readSimplePrefixCode(pd *prefixDecoder, maxSyms uint) {
256	var codes [4]prefixCode
257	nsym := int(br.ReadBits(2)) + 1
258	clen := neededBits(uint32(maxSyms))
259	for i := 0; i < nsym; i++ {
260		codes[i].sym = uint32(br.ReadBits(clen))
261	}
262
263	copyLens := func(lens []uint) {
264		for i := 0; i < nsym; i++ {
265			codes[i].len = uint32(lens[i])
266		}
267	}
268	compareSwap := func(i, j int) {
269		if codes[i].sym > codes[j].sym {
270			codes[i], codes[j] = codes[j], codes[i]
271		}
272	}
273
274	switch nsym {
275	case 1:
276		copyLens(simpleLens1[:])
277	case 2:
278		copyLens(simpleLens2[:])
279		compareSwap(0, 1)
280	case 3:
281		copyLens(simpleLens3[:])
282		compareSwap(0, 1)
283		compareSwap(0, 2)
284		compareSwap(1, 2)
285	case 4:
286		if tsel := br.ReadBits(1) == 1; !tsel {
287			copyLens(simpleLens4a[:])
288		} else {
289			copyLens(simpleLens4b[:])
290		}
291		compareSwap(0, 1)
292		compareSwap(2, 3)
293		compareSwap(0, 2)
294		compareSwap(1, 3)
295		compareSwap(1, 2)
296	}
297	if uint(codes[nsym-1].sym) >= maxSyms {
298		errors.Panic(errCorrupted) // Symbol goes beyond range of alphabet
299	}
300	pd.Init(codes[:nsym], true) // Must have 1..4 symbols
301}
302
// readComplexPrefixCode reads the prefix code according to RFC section 3.5
// and initializes pd with it.
//
// Decoding happens in two stages:
//
//  1. Code lengths for the code-length alphabet are read in the symbol order
//     given by complexLens[hskip:], each decoded with decCLens. Reading stops
//     early once the remaining code space (sum, in units of 1/32) drops to
//     zero or below. The non-zero lengths, ordered by symbol, build
//     br.prefix.
//  2. Code lengths for the target alphabet are decoded with br.prefix. This
//     continues until maxSyms symbols are covered or the remaining code space
//     (now in units of 1/32768) is used up. Values below 16 are literal
//     lengths, where 0 leaves the symbol unused. Symbol 16 repeats the last
//     non-zero literal length, or 8 if none has been seen. Symbol 17 skips
//     symbols, leaving their length at zero. A repeat code that immediately
//     follows the same repeat code extends the earlier run rather than
//     starting a new one.
//
// It panics with errCorrupted if the code-length code is empty, if fewer than
// two symbols receive a non-zero length, or if a run extends past maxSyms.
// NOTE(review): completeness of the final code is presumably checked by
// pd.Init; confirm.
func (br *bitReader) readComplexPrefixCode(pd *prefixDecoder, maxSyms, hskip uint) {
	// Read the code-lengths prefix table.
	var codeCLensArr [len(complexLens)]prefixCode // Sorted, but may have holes
	sum := 32 // Remaining code space for code-length codes, in 1/32 units
	for _, sym := range complexLens[hskip:] {
		clen := br.ReadSymbol(&decCLens)
		if clen > 0 {
			codeCLensArr[sym] = prefixCode{sym: uint32(sym), len: uint32(clen)}
			if sum -= 32 >> clen; sum <= 0 {
				break
			}
		}
	}
	codeCLens := codeCLensArr[:0] // Compact the array to have no holes
	for _, c := range codeCLensArr {
		if c.len > 0 {
			codeCLens = append(codeCLens, c)
		}
	}
	if len(codeCLens) < 1 {
		errors.Panic(errCorrupted)
	}
	br.prefix.Init(codeCLens, true) // Must have 1..len(complexLens) symbols

	// Use code-lengths table to decode rest of prefix table.
	// clenLast starts at 8, the length repeated by code 16 before any
	// non-zero literal length has been read.
	var codesArr [maxNumAlphabetSyms]prefixCode
	var sym, repSymLast, repCntLast, clenLast uint = 0, 0, 0, 8
	codes := codesArr[:0]
	// sum now tracks the remaining code space in units of 1/32768.
	for sym, sum = 0, 32768; sym < maxSyms && sum > 0; {
		clen := br.ReadSymbol(&br.prefix)
		if clen < 16 {
			// Literal bit-length symbol used.
			if clen > 0 {
				codes = append(codes, prefixCode{sym: uint32(sym), len: uint32(clen)})
				clenLast = clen
				sum -= 32768 >> clen
			}
			repSymLast = 0 // Reset last repeater symbol
			sym++
		} else {
			// Repeater symbol used.
			//	16: Repeat previous non-zero code-length
			//	17: Repeat code length of zero

			repSym := clen // Rename clen for better clarity
			if repSym != repSymLast {
				repCntLast = 0
				repSymLast = repSym
			}

			// Back-to-back repeats of the same kind combine into one larger
			// count; repDiff is how many symbols this step adds.
			nb := repSym - 14          // 2..3 bits
			rep := br.ReadBits(nb) + 3 // 3..6 or 3..10
			if repCntLast > 0 {
				rep += (repCntLast - 2) << nb // Modify previous repeat count
			}
			repDiff := rep - repCntLast // Always positive
			repCntLast = rep

			if repSym == 16 {
				clen := clenLast
				for symEnd := sym + repDiff; sym < symEnd; sym++ {
					codes = append(codes, prefixCode{sym: uint32(sym), len: uint32(clen)})
				}
				sum -= int(repDiff) * (32768 >> clen)
			} else {
				sym += repDiff // Skipped symbols keep a zero length
			}
		}
	}
	if len(codes) < 2 || sym > maxSyms {
		errors.Panic(errCorrupted)
	}
	pd.Init(codes, true) // Must have 2..maxSyms symbols
}
378