1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package prefix
6
7import (
8	"bufio"
9	"bytes"
10	"encoding/binary"
11	"io"
12	"strings"
13
14	"github.com/dsnet/compress"
15	"github.com/dsnet/compress/internal"
16	"github.com/dsnet/compress/internal/errors"
17)
18
// Reader implements a prefix decoder. If the input io.Reader satisfies the
// compress.ByteReader or compress.BufferedReader interface, then it also
// guarantees that it will never read more bytes than is necessary.
//
// For high performance, provide an io.Reader that satisfies the
// compress.BufferedReader interface. If the input does not satisfy either
// compress.ByteReader or compress.BufferedReader, then it will be internally
// wrapped with a bufio.Reader.
//
// Init must be called before bits are read; it selects whether byteRd or
// bufRd is used to fetch input.
type Reader struct {
	Offset int64 // Number of bytes read from the underlying io.Reader

	// Init stores the provided reader in rd (or the internal bufio.Reader
	// when it has to wrap the input) and sets exactly one of byteRd and bufRd.
	rd     io.Reader
	byteRd compress.ByteReader     // Set if rd is only a ByteReader
	bufRd  compress.BufferedReader // Set if rd is a BufferedReader

	// Bits are consumed from the least-significant end of bufBits. In
	// big-endian mode each input byte is bit-reversed before it is added.
	bufBits   uint64 // Buffer to hold some bits
	numBits   uint   // Number of valid bits in bufBits
	bigEndian bool   // Do we treat input bytes as big endian?

	// These fields are only used if rd is a compress.BufferedReader.
	bufPeek     []byte // Peeked bytes not yet copied into bufBits; nil after Flush
	discardBits int    // Number of bits to discard from reader (-7..0 after Flush)
	fedBits     uint   // Value of numBits right after the last feed or Flush

	// These fields are used to reduce allocations.
	// Init carries them over so the wrappers can be reused between calls.
	bb *buffer
	br *bytesReader
	sr *stringReader
	bu *bufio.Reader
}
49
50// Init initializes the bit Reader to read from r. If bigEndian is true, then
51// bits will be read starting from the most-significant bits of a byte
52// (as done in bzip2), otherwise it will read starting from the
53// least-significant bits of a byte (such as for deflate and brotli).
54func (pr *Reader) Init(r io.Reader, bigEndian bool) {
55	*pr = Reader{
56		rd:        r,
57		bigEndian: bigEndian,
58
59		bb: pr.bb,
60		br: pr.br,
61		sr: pr.sr,
62		bu: pr.bu,
63	}
64	switch rr := r.(type) {
65	case *bytes.Buffer:
66		if pr.bb == nil {
67			pr.bb = new(buffer)
68		}
69		*pr.bb = buffer{Buffer: rr}
70		pr.bufRd = pr.bb
71	case *bytes.Reader:
72		if pr.br == nil {
73			pr.br = new(bytesReader)
74		}
75		*pr.br = bytesReader{Reader: rr}
76		pr.bufRd = pr.br
77	case *strings.Reader:
78		if pr.sr == nil {
79			pr.sr = new(stringReader)
80		}
81		*pr.sr = stringReader{Reader: rr}
82		pr.bufRd = pr.sr
83	case compress.BufferedReader:
84		pr.bufRd = rr
85	case compress.ByteReader:
86		pr.byteRd = rr
87	default:
88		if pr.bu == nil {
89			pr.bu = bufio.NewReader(nil)
90		}
91		pr.bu.Reset(r)
92		pr.rd, pr.bufRd = pr.bu, pr.bu
93	}
94}
95
96// BitsRead reports the total number of bits emitted from any Read method.
97func (pr *Reader) BitsRead() int64 {
98	offset := 8*pr.Offset - int64(pr.numBits)
99	if pr.bufRd != nil {
100		discardBits := pr.discardBits + int(pr.fedBits-pr.numBits)
101		offset = 8*pr.Offset + int64(discardBits)
102	}
103	return offset
104}
105
// IsBufferedReader reports whether the underlying io.Reader is also a
// compress.BufferedReader. This is the case whenever Init selected the
// Peek/Discard path, which includes inputs that Init wrapped itself.
func (pr *Reader) IsBufferedReader() bool {
	return pr.bufRd != nil
}
111
112// ReadPads reads 0-7 bits from the bit buffer to achieve byte-alignment.
113func (pr *Reader) ReadPads() uint {
114	nb := pr.numBits % 8
115	val := uint(pr.bufBits & uint64(1<<nb-1))
116	pr.bufBits >>= nb
117	pr.numBits -= nb
118	return val
119}
120
121// Read reads bytes into buf.
122// The bit-ordering mode does not affect this method.
123func (pr *Reader) Read(buf []byte) (cnt int, err error) {
124	if pr.numBits > 0 {
125		if pr.numBits%8 != 0 {
126			return 0, errorf(errors.Invalid, "non-aligned bit buffer")
127		}
128		for cnt = 0; len(buf) > cnt && pr.numBits > 0; cnt++ {
129			if pr.bigEndian {
130				buf[cnt] = internal.ReverseLUT[byte(pr.bufBits)]
131			} else {
132				buf[cnt] = byte(pr.bufBits)
133			}
134			pr.bufBits >>= 8
135			pr.numBits -= 8
136		}
137		return cnt, nil
138	}
139	if _, err := pr.Flush(); err != nil {
140		return 0, err
141	}
142	cnt, err = pr.rd.Read(buf)
143	pr.Offset += int64(cnt)
144	return cnt, err
145}
146
147// ReadOffset reads an offset value using the provided RangeCodes indexed by
148// the symbol read.
149func (pr *Reader) ReadOffset(pd *Decoder, rcs RangeCodes) uint {
150	rc := rcs[pr.ReadSymbol(pd)]
151	return uint(rc.Base) + pr.ReadBits(uint(rc.Len))
152}
153
154// TryReadBits attempts to read nb bits using the contents of the bit buffer
155// alone. It returns the value and whether it succeeded.
156//
157// This method is designed to be inlined for performance reasons.
158func (pr *Reader) TryReadBits(nb uint) (uint, bool) {
159	if pr.numBits < nb {
160		return 0, false
161	}
162	val := uint(pr.bufBits & uint64(1<<nb-1))
163	pr.bufBits >>= nb
164	pr.numBits -= nb
165	return val, true
166}
167
168// ReadBits reads nb bits in from the underlying reader.
169func (pr *Reader) ReadBits(nb uint) uint {
170	if err := pr.PullBits(nb); err != nil {
171		errors.Panic(err)
172	}
173	val := uint(pr.bufBits & uint64(1<<nb-1))
174	pr.bufBits >>= nb
175	pr.numBits -= nb
176	return val
177}
178
179// TryReadSymbol attempts to decode the next symbol using the contents of the
180// bit buffer alone. It returns the decoded symbol and whether it succeeded.
181//
182// This method is designed to be inlined for performance reasons.
183func (pr *Reader) TryReadSymbol(pd *Decoder) (uint, bool) {
184	if pr.numBits < uint(pd.MinBits) || len(pd.chunks) == 0 {
185		return 0, false
186	}
187	chunk := pd.chunks[uint32(pr.bufBits)&pd.chunkMask]
188	nb := uint(chunk & countMask)
189	if nb > pr.numBits || nb > uint(pd.chunkBits) {
190		return 0, false
191	}
192	pr.bufBits >>= nb
193	pr.numBits -= nb
194	return uint(chunk >> countBits), true
195}
196
197// ReadSymbol reads the next symbol using the provided prefix Decoder.
198func (pr *Reader) ReadSymbol(pd *Decoder) uint {
199	if len(pd.chunks) == 0 {
200		panicf(errors.Invalid, "decode with empty prefix tree")
201	}
202
203	nb := uint(pd.MinBits)
204	for {
205		if err := pr.PullBits(nb); err != nil {
206			errors.Panic(err)
207		}
208		chunk := pd.chunks[uint32(pr.bufBits)&pd.chunkMask]
209		nb = uint(chunk & countMask)
210		if nb > uint(pd.chunkBits) {
211			linkIdx := chunk >> countBits
212			chunk = pd.links[linkIdx][uint32(pr.bufBits>>pd.chunkBits)&pd.linkMask]
213			nb = uint(chunk & countMask)
214		}
215		if nb <= pr.numBits {
216			pr.bufBits >>= nb
217			pr.numBits -= nb
218			return uint(chunk >> countBits)
219		}
220	}
221}
222
223// Flush updates the read offset of the underlying ByteReader.
224// If reader is a compress.BufferedReader, then this calls Discard to update
225// the read offset.
226func (pr *Reader) Flush() (int64, error) {
227	if pr.bufRd == nil {
228		return pr.Offset, nil
229	}
230
231	// Update the number of total bits to discard.
232	pr.discardBits += int(pr.fedBits - pr.numBits)
233	pr.fedBits = pr.numBits
234
235	// Discard some bytes to update read offset.
236	var err error
237	nd := (pr.discardBits + 7) / 8 // Round up to nearest byte
238	nd, err = pr.bufRd.Discard(nd)
239	pr.discardBits -= nd * 8 // -7..0
240	pr.Offset += int64(nd)
241
242	// These are invalid after Discard.
243	pr.bufPeek = nil
244	return pr.Offset, err
245}
246
// PullBits ensures that at least nb bits exist in the bit buffer.
// If the underlying reader is a compress.BufferedReader, then this will fill
// the bit buffer with as many bits as possible, relying on Peek and Discard to
// properly advance the read offset. Otherwise, it will use ReadByte to fill the
// buffer with just the right number of bits.
//
// An io.EOF from the source before nb bits are available is reported as
// io.ErrUnexpectedEOF. Other errors are returned unchanged.
func (pr *Reader) PullBits(nb uint) error {
	if pr.bufRd != nil {
		// Account for the bits consumed since the buffer was last fed so that
		// they become eligible for Discard.
		pr.discardBits += int(pr.fedBits - pr.numBits)
		for {
			if len(pr.bufPeek) == 0 {
				// No peeked bytes remain: commit consumed bytes through Flush,
				// then Peek a fresh window.
				pr.fedBits = pr.numBits // Don't discard bits just added
				if _, err := pr.Flush(); err != nil {
					return err
				}

				// Peek no more bytes than necessary.
				// The computation for cntPeek computes the minimum number of
				// bytes to Peek to fill nb bits.
				var err error
				cntPeek := int(nb+(-nb&7)) / 8
				if cntPeek < pr.bufRd.Buffered() {
					// Bytes that are already buffered can be peeked freely.
					cntPeek = pr.bufRd.Buffered()
				}
				pr.bufPeek, err = pr.bufRd.Peek(cntPeek)
				pr.bufPeek = pr.bufPeek[int(pr.numBits/8):] // Skip buffered bits
				if len(pr.bufPeek) == 0 {
					// Nothing new to feed: succeed if the request is already
					// satisfied, otherwise report the Peek error.
					if pr.numBits >= nb {
						break
					}
					if err == io.EOF {
						err = io.ErrUnexpectedEOF
					}
					return err
				}
			}

			n := int(64-pr.numBits) / 8 // Number of bytes to copy to bit buffer
			if len(pr.bufPeek) >= 8 {
				// Starting with Go 1.7, the compiler should use a wide integer
				// load here if the architecture supports it.
				u := binary.LittleEndian.Uint64(pr.bufPeek)
				if pr.bigEndian {
					// Swap all the bits within each byte.
					u = (u&0xaaaaaaaaaaaaaaaa)>>1 | (u&0x5555555555555555)<<1
					u = (u&0xcccccccccccccccc)>>2 | (u&0x3333333333333333)<<2
					u = (u&0xf0f0f0f0f0f0f0f0)>>4 | (u&0x0f0f0f0f0f0f0f0f)<<4
				}

				// Only n whole bytes are counted as added, and bufPeek only
				// advances by n bytes.
				pr.bufBits |= u << pr.numBits
				pr.numBits += uint(n * 8)
				pr.bufPeek = pr.bufPeek[n:]
				break
			} else {
				// Fewer than 8 peeked bytes: feed them one at a time, but no
				// more than fit into the bit buffer.
				if n > len(pr.bufPeek) {
					n = len(pr.bufPeek)
				}
				for _, c := range pr.bufPeek[:n] {
					if pr.bigEndian {
						c = internal.ReverseLUT[c]
					}
					pr.bufBits |= uint64(c) << pr.numBits
					pr.numBits += 8
				}
				pr.bufPeek = pr.bufPeek[n:]
				// Stop once another whole byte can no longer fit; otherwise
				// loop to Peek more input.
				if pr.numBits > 56 {
					break
				}
			}
		}
		// Remember how many bits the buffer held after feeding, so that later
		// consumption can be computed as fedBits - numBits.
		pr.fedBits = pr.numBits
	} else {
		// ByteReader path: read exactly as many bytes as needed, counting each
		// one in Offset as it is loaded.
		for pr.numBits < nb {
			c, err := pr.byteRd.ReadByte()
			if err != nil {
				if err == io.EOF {
					err = io.ErrUnexpectedEOF
				}
				return err
			}
			if pr.bigEndian {
				c = internal.ReverseLUT[c]
			}
			pr.bufBits |= uint64(c) << pr.numBits
			pr.numBits += 8
			pr.Offset++
		}
	}
	return nil
}
336