1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package brotli
6
7import (
8	"io"
9	"io/ioutil"
10
11	"github.com/dsnet/compress/internal"
12	"github.com/dsnet/compress/internal/errors"
13)
14
// Reader holds the state needed to decompress a Brotli stream pulled from an
// underlying io.Reader. Decompression proceeds as a sequence of steps: each
// call to the step function performs one unit of work and selects the step
// to run next.
//
// The zero value has a nil step function; NewReader and Reset install the
// initial step, so a Reader should be obtained through one of them.
type Reader struct {
	InputOffset  int64 // Total number of bytes read from underlying io.Reader
	OutputOffset int64 // Total number of bytes emitted from Read

	rd     bitReader // Input source
	toRead []byte    // Uncompressed data ready to be emitted from Read
	blkLen int       // Uncompressed bytes left to read in meta-block
	insLen int       // Bytes left to insert in current command
	cpyLen int       // Bytes left to copy in current command
	last   bool      // Last block bit detected
	err    error     // Persistent error

	step      func(*Reader) // Single step of decompression work (can panic)
	stepState int           // The sub-step state for certain steps

	mtf     internal.MoveToFront // Local move-to-front decoder
	dict    dictDecoder          // Dynamic sliding dictionary
	iacBlk  blockDecoder         // Insert-and-copy block decoder
	litBlk  blockDecoder         // Literal block decoder
	distBlk blockDecoder         // Distance block decoder

	// Literal decoding state fields.
	litMapType []uint8 // The current literal context map for the current block type
	litMap     []uint8 // Literal context map
	cmode      uint8   // The current context mode
	cmodes     []uint8 // Literal context modes

	// Distance decoding state fields.
	distMap     []uint8 // Distance context map
	distMapType []uint8 // The current distance context map for the current block type
	dist        int     // The current distance (may not be in dists)
	dists       [4]int  // Last few distances (newest-to-oldest)
	distZero    bool    // Implicit zero distance symbol found
	npostfix    uint8   // Postfix bits used in distance decoding
	ndirect     uint8   // Number of direct distance codes

	// Static dictionary state fields.
	word    []byte            // Transformed word obtained from static dictionary
	wordBuf [maxWordSize]byte // Buffer to write a transformed word into

	// Meta data fields.
	metaRd  io.LimitedReader // Local LimitedReader to reduce allocation
	metaWr  io.Writer        // Writer to write meta data to
	metaBuf []byte           // Scratch space for reading meta data
}
60
// blockDecoder tracks block-type state for one category of symbols. Reader
// keeps one each for literals, insert-and-copy lengths, and distances.
type blockDecoder struct {
	numTypes int // Total number of types
	// typeLen is decremented once per decoded symbol; a block switch is read
	// when it is zero. readPrefixCodes starts it at -1 so that, when there is
	// only a single type, it never reaches zero within the meta-block.
	typeLen  int             // The number of blocks left for this type
	types    [2]uint8        // The current (0) and previous (1) block type
	decType  prefixDecoder   // Prefix decoder for the type symbol
	decLen   prefixDecoder   // Prefix decoder for block length
	prefixes []prefixDecoder // Prefix decoders for each block type
}
69
// ReaderConfig carries the options accepted by NewReader. It currently
// declares no settable options, and NewReader does not consult it.
type ReaderConfig struct {
	_ struct{} // Blank field to prevent unkeyed struct literals
}
73
74func NewReader(r io.Reader, conf *ReaderConfig) (*Reader, error) {
75	br := new(Reader)
76	br.Reset(r)
77	return br, nil
78}
79
80func (br *Reader) Read(buf []byte) (int, error) {
81	for {
82		if len(br.toRead) > 0 {
83			cnt := copy(buf, br.toRead)
84			br.toRead = br.toRead[cnt:]
85			br.OutputOffset += int64(cnt)
86			return cnt, nil
87		}
88		if br.err != nil {
89			return 0, br.err
90		}
91
92		// Perform next step in decompression process.
93		br.rd.offset = br.InputOffset
94		func() {
95			defer errors.Recover(&br.err)
96			br.step(br)
97		}()
98		br.InputOffset = br.rd.FlushOffset()
99		if br.err != nil {
100			br.toRead = br.dict.ReadFlush() // Flush what's left in case of error
101		}
102	}
103}
104
105func (br *Reader) Close() error {
106	if br.err == io.EOF || br.err == io.ErrClosedPipe {
107		br.toRead = nil // Make sure future reads fail
108		br.err = io.ErrClosedPipe
109		return nil
110	}
111	return br.err // Return the persistent error
112}
113
114func (br *Reader) Reset(r io.Reader) error {
115	*br = Reader{
116		rd:   br.rd,
117		step: (*Reader).readStreamHeader,
118
119		dict:    br.dict,
120		iacBlk:  br.iacBlk,
121		litBlk:  br.litBlk,
122		distBlk: br.distBlk,
123		word:    br.word[:0],
124		cmodes:  br.cmodes[:0],
125		litMap:  br.litMap[:0],
126		distMap: br.distMap[:0],
127		dists:   [4]int{4, 11, 15, 16}, // RFC section 4
128
129		// TODO(dsnet): Should we write meta data somewhere useful?
130		metaWr:  ioutil.Discard,
131		metaBuf: br.metaBuf,
132	}
133	br.rd.Init(r)
134	return nil
135}
136
137// readStreamHeader reads the Brotli stream header according to RFC section 9.1.
138func (br *Reader) readStreamHeader() {
139	wbits := br.rd.ReadSymbol(&decWinBits)
140	if wbits == 0 {
141		errors.Panic(errCorrupted) // Reserved value used
142	}
143	size := int(1<<wbits) - 16
144	br.dict.Init(size)
145	br.readBlockHeader()
146}
147
148// readBlockHeader reads a meta-block header according to RFC section 9.2.
149func (br *Reader) readBlockHeader() {
150	if br.last {
151		if br.rd.ReadPads() > 0 {
152			errors.Panic(errCorrupted)
153		}
154		errors.Panic(io.EOF)
155	}
156
157	// Read ISLAST and ISLASTEMPTY.
158	if br.last = br.rd.ReadBits(1) == 1; br.last {
159		if empty := br.rd.ReadBits(1) == 1; empty {
160			br.readBlockHeader() // Next call will terminate stream
161			return
162		}
163	}
164
165	// Read MLEN and MNIBBLES and process meta data.
166	var blkLen int // 1..1<<24
167	nibbles := br.rd.ReadBits(2) + 4
168	if nibbles == 7 {
169		if reserved := br.rd.ReadBits(1) == 1; reserved {
170			errors.Panic(errCorrupted)
171		}
172
173		var skipLen int // 0..1<<24
174		if skipBytes := br.rd.ReadBits(2); skipBytes > 0 {
175			skipLen = int(br.rd.ReadBits(skipBytes * 8))
176			if skipBytes > 1 && skipLen>>((skipBytes-1)*8) == 0 {
177				errors.Panic(errCorrupted) // Shortest representation not used
178			}
179			skipLen++
180		}
181
182		if br.rd.ReadPads() > 0 {
183			errors.Panic(errCorrupted)
184		}
185		br.blkLen = skipLen // Use blkLen to track metadata number of bytes
186		br.readMetaData()
187		return
188	}
189	blkLen = int(br.rd.ReadBits(nibbles * 4))
190	if nibbles > 4 && blkLen>>((nibbles-1)*4) == 0 {
191		errors.Panic(errCorrupted) // Shortest representation not used
192	}
193	br.blkLen = blkLen + 1
194
195	// Read ISUNCOMPRESSED and process uncompressed data.
196	if !br.last {
197		if uncompressed := br.rd.ReadBits(1) == 1; uncompressed {
198			if br.rd.ReadPads() > 0 {
199				errors.Panic(errCorrupted)
200			}
201			br.readRawData()
202			return
203		}
204	}
205	br.readPrefixCodes()
206}
207
208// readMetaData reads meta data according to RFC section 9.2.
209func (br *Reader) readMetaData() {
210	br.metaRd.R = &br.rd
211	br.metaRd.N = int64(br.blkLen)
212	if br.metaBuf == nil {
213		br.metaBuf = make([]byte, 4096) // Lazy allocate
214	}
215	if cnt, err := io.CopyBuffer(br.metaWr, &br.metaRd, br.metaBuf); err != nil {
216		errors.Panic(err) // Will never panic with io.EOF
217	} else if cnt < int64(br.blkLen) {
218		errors.Panic(io.ErrUnexpectedEOF)
219	}
220	br.step = (*Reader).readBlockHeader
221}
222
223// readRawData reads raw data according to RFC section 9.2.
224func (br *Reader) readRawData() {
225	buf := br.dict.WriteSlice()
226	if len(buf) > br.blkLen {
227		buf = buf[:br.blkLen]
228	}
229
230	cnt, err := br.rd.Read(buf)
231	br.blkLen -= cnt
232	br.dict.WriteMark(cnt)
233	if err != nil {
234		if err == io.EOF {
235			err = io.ErrUnexpectedEOF
236		}
237		errors.Panic(err)
238	}
239
240	if br.blkLen > 0 {
241		br.toRead = br.dict.ReadFlush()
242		br.step = (*Reader).readRawData // We need to continue this work
243		return
244	}
245	br.step = (*Reader).readBlockHeader
246}
247
248// readPrefixCodes reads the prefix codes according to RFC section 9.2.
249func (br *Reader) readPrefixCodes() {
250	// Read block types for literal, insert-and-copy, and distance blocks.
251	for _, bd := range []*blockDecoder{&br.litBlk, &br.iacBlk, &br.distBlk} {
252		// Note: According to RFC section 6, it is okay for the block count to
253		// *not* count down to zero. Thus, there is no need to validate that
254		// typeLen is within some reasonable range.
255		bd.types = [2]uint8{0, 1}
256		bd.typeLen = -1 // Stay on this type until next meta-block
257
258		bd.numTypes = int(br.rd.ReadSymbol(&decCounts)) // 1..256
259		if bd.numTypes >= 2 {
260			br.rd.ReadPrefixCode(&bd.decType, uint(bd.numTypes)+2)
261			br.rd.ReadPrefixCode(&bd.decLen, uint(numBlkCntSyms))
262			sym := br.rd.ReadSymbol(&bd.decLen)
263			bd.typeLen = int(br.rd.ReadOffset(sym, blkLenRanges))
264		}
265	}
266
267	// Read NPOSTFIX and NDIRECT.
268	npostfix := br.rd.ReadBits(2)            // 0..3
269	ndirect := br.rd.ReadBits(4) << npostfix // 0..120
270	br.npostfix, br.ndirect = uint8(npostfix), uint8(ndirect)
271	numDistSyms := 16 + ndirect + 48<<npostfix
272
273	// Read CMODE, the literal context modes.
274	br.cmodes = allocUint8s(br.cmodes, br.litBlk.numTypes)
275	for i := range br.cmodes {
276		br.cmodes[i] = uint8(br.rd.ReadBits(2))
277	}
278	br.cmode = br.cmodes[0] // 0..3
279
280	// Read CMAPL, the literal context map.
281	numLitTrees := int(br.rd.ReadSymbol(&decCounts)) // 1..256
282	br.litMap = allocUint8s(br.litMap, maxLitContextIDs*br.litBlk.numTypes)
283	if numLitTrees >= 2 {
284		br.readContextMap(br.litMap, uint(numLitTrees))
285	} else {
286		for i := range br.litMap {
287			br.litMap[i] = 0
288		}
289	}
290	br.litMapType = br.litMap[0:] // First block type is zero
291
292	// Read CMAPD, the distance context map.
293	numDistTrees := int(br.rd.ReadSymbol(&decCounts)) // 1..256
294	br.distMap = allocUint8s(br.distMap, maxDistContextIDs*br.distBlk.numTypes)
295	if numDistTrees >= 2 {
296		br.readContextMap(br.distMap, uint(numDistTrees))
297	} else {
298		for i := range br.distMap {
299			br.distMap[i] = 0
300		}
301	}
302	br.distMapType = br.distMap[0:] // First block type is zero
303
304	// Read HTREEL[], HTREEI[], and HTREED[], the arrays of prefix codes.
305	br.litBlk.prefixes = extendDecoders(br.litBlk.prefixes, numLitTrees)
306	for i := range br.litBlk.prefixes {
307		br.rd.ReadPrefixCode(&br.litBlk.prefixes[i], numLitSyms)
308	}
309	br.iacBlk.prefixes = extendDecoders(br.iacBlk.prefixes, br.iacBlk.numTypes)
310	for i := range br.iacBlk.prefixes {
311		br.rd.ReadPrefixCode(&br.iacBlk.prefixes[i], numIaCSyms)
312	}
313	br.distBlk.prefixes = extendDecoders(br.distBlk.prefixes, numDistTrees)
314	for i := range br.distBlk.prefixes {
315		br.rd.ReadPrefixCode(&br.distBlk.prefixes[i], numDistSyms)
316	}
317
318	br.step = (*Reader).readCommands
319}
320
// readCommands reads block commands according to RFC section 9.3.
//
// Each command consists of an insert length (literals to decode), a copy
// length, and a distance. The function runs commands until the meta-block's
// remaining length (blkLen) reaches zero, at which point it flushes the
// dictionary and schedules readBlockHeader. A negative blkLen is reported
// as errCorrupted.
//
// Whenever the literal, dynamic-copy, or static-copy phase cannot finish in
// one go, the function flushes the dictionary into toRead, records the phase
// in stepState, and returns; the next call resumes at that phase.
func (br *Reader) readCommands() {
	// Since Go does not support tail call optimization, we use goto statements
	// to achieve higher performance processing each command. Each label can be
	// thought of as a mini function, and each goto as a cheap function call.
	// The following code follows this control flow.
	//
	// The bulk of the action will be in the following loop:
	//	startCommand -> readLiterals -> readDistance -> copyDynamicDict ->
	//		finishCommand -> startCommand -> ...
	/*
		             readCommands()
		                   |
		+----------------> +
		|                  |
		|                  V
		|         +-- startCommand
		|         |        |
		|         |        V
		|         |   readLiterals ----------+
		|         |        |                 |
		|         |        V                 |
		|         +-> readDistance           |
		|                  |                 |
		|         +--------+--------+        |
		|         |                 |        |
		|         V                 V        |
		|  copyDynamicDict   copyStaticDict  |
		|         |                 |        |
		|         +--------+--------+        |
		|                  |                 |
		|                  V                 |
		+----------- finishCommand <---------+
		                   |
		                   V
		           readBlockHeader()
	*/

	const (
		stateInit = iota // Zero value must be stateInit

		// Some labels (readLiterals, copyDynamicDict, copyStaticDict) require
		// work to be continued if more buffer space is needed. This is achieved
		// by the switch block right below, which continues the work at the
		// right label based on the given sub-step value.
		stateLiterals
		stateDynamicDict
		stateStaticDict
	)

	// Resume at the phase recorded by the previous call, if any.
	switch br.stepState {
	case stateInit:
		goto startCommand
	case stateLiterals:
		goto readLiterals
	case stateDynamicDict:
		goto copyDynamicDict
	case stateStaticDict:
		goto copyStaticDict
	}

startCommand:
	// Read the insert and copy lengths according to RFC section 5.
	{
		if br.iacBlk.typeLen == 0 {
			br.readBlockSwitch(&br.iacBlk)
		}
		br.iacBlk.typeLen--

		// Each read below first attempts the Try variant and falls back to
		// the regular variant when that reports failure (presumably a fast
		// path that needs enough buffered bits; confirm in bitReader).
		iacTree := &br.iacBlk.prefixes[br.iacBlk.types[0]]
		iacSym, ok := br.rd.TryReadSymbol(iacTree)
		if !ok {
			iacSym = br.rd.ReadSymbol(iacTree)
		}
		rec := iacLUT[iacSym]
		insExtra, ok := br.rd.TryReadBits(uint(rec.ins.bits))
		if !ok {
			insExtra = br.rd.ReadBits(uint(rec.ins.bits))
		}
		cpyExtra, ok := br.rd.TryReadBits(uint(rec.cpy.bits))
		if !ok {
			cpyExtra = br.rd.ReadBits(uint(rec.cpy.bits))
		}
		br.insLen = int(rec.ins.base) + int(insExtra)
		br.cpyLen = int(rec.cpy.base) + int(cpyExtra)
		// Symbols below 128 imply a zero distance symbol: readDistance then
		// reuses dists[0] without reading a distance symbol.
		br.distZero = iacSym < 128
		if br.insLen > 0 {
			goto readLiterals
		}
		goto readDistance
	}

readLiterals:
	// Read literal symbols as uncompressed data according to RFC section 9.3.
	{
		// Decode at most as many literals as fit in the dictionary's current
		// write window; the remainder is handled after a flush.
		buf := br.dict.WriteSlice()
		if len(buf) > br.insLen {
			buf = buf[:br.insLen]
		}

		// p1 and p2 are the two most recent output bytes; they select the
		// literal context for the next symbol.
		p1, p2 := br.dict.LastBytes()
		for i := range buf {
			if br.litBlk.typeLen == 0 {
				br.readBlockSwitch(&br.litBlk)
				br.litMapType = br.litMap[64*int(br.litBlk.types[0]):]
				br.cmode = br.cmodes[br.litBlk.types[0]] // 0..3
			}
			br.litBlk.typeLen--

			litCID := getLitContextID(p1, p2, br.cmode) // 0..63
			litTree := &br.litBlk.prefixes[br.litMapType[litCID]]
			litSym, ok := br.rd.TryReadSymbol(litTree)
			if !ok {
				litSym = br.rd.ReadSymbol(litTree)
			}

			buf[i] = byte(litSym)
			p1, p2 = byte(litSym), p1
			br.dict.WriteMark(1)
		}
		br.insLen -= len(buf)
		br.blkLen -= len(buf)

		if br.insLen > 0 {
			br.toRead = br.dict.ReadFlush()
			br.step = (*Reader).readCommands
			br.stepState = stateLiterals // Need to continue work here
			return
		}
		// If the literals used up the meta-block, the copy part of this
		// command is skipped entirely.
		if br.blkLen > 0 {
			goto readDistance
		}
		goto finishCommand
	}

readDistance:
	// Read and decode the distance length according to RFC section 9.3.
	{
		if br.distZero {
			br.dist = br.dists[0]
		} else {
			if br.distBlk.typeLen == 0 {
				br.readBlockSwitch(&br.distBlk)
				br.distMapType = br.distMap[4*int(br.distBlk.types[0]):]
			}
			br.distBlk.typeLen--

			distCID := getDistContextID(br.cpyLen) // 0..3
			distTree := &br.distBlk.prefixes[br.distMapType[distCID]]
			distSym, ok := br.rd.TryReadSymbol(distTree)
			if !ok {
				distSym = br.rd.ReadSymbol(distTree)
			}

			if distSym < 16 { // Short-code
				// Derived from one of the recent distances plus a delta.
				rec := distShortLUT[distSym]
				br.dist = br.dists[rec.index] + rec.delta
			} else if distSym < uint(16+br.ndirect) { // Direct-code
				br.dist = int(distSym - 15) // 1..ndirect
			} else { // Long-code
				rec := distLongLUT[br.npostfix][distSym-uint(16+br.ndirect)]
				extra, ok := br.rd.TryReadBits(uint(rec.bits))
				if !ok {
					extra = br.rd.ReadBits(uint(rec.bits))
				}
				br.dist = int(br.ndirect) + int(rec.base) + int(extra<<br.npostfix)
			}
			// An explicit symbol zero is treated like the implicit one: the
			// distance is not pushed onto the recent-distance list below.
			br.distZero = bool(distSym == 0)
			if br.dist <= 0 {
				errors.Panic(errCorrupted)
			}
		}

		// Distances within the available history refer to previous output;
		// larger ones refer to the static dictionary.
		if br.dist <= br.dict.HistSize() {
			if !br.distZero {
				// Push the new distance, dropping the oldest.
				br.dists[3] = br.dists[2]
				br.dists[2] = br.dists[1]
				br.dists[1] = br.dists[0]
				br.dists[0] = br.dist
			}
			goto copyDynamicDict
		}
		goto copyStaticDict
	}

copyDynamicDict:
	// Copy a string from the past uncompressed data according to RFC section 2.
	{
		// WriteCopy reports how many bytes it copied, which may be fewer
		// than requested; the rest is copied after a flush.
		cnt := br.dict.WriteCopy(br.dist, br.cpyLen)
		br.blkLen -= cnt
		br.cpyLen -= cnt

		if br.cpyLen > 0 {
			br.toRead = br.dict.ReadFlush()
			br.step = (*Reader).readCommands
			br.stepState = stateDynamicDict // Need to continue work here
			return
		}
		goto finishCommand
	}

copyStaticDict:
	// Copy a string from the static dictionary according to RFC section 8.
	{
		// An empty br.word means no word is pending from an interrupted copy,
		// so look up and transform a new one.
		if len(br.word) == 0 {
			if br.cpyLen < minDictLen || br.cpyLen > maxDictLen {
				errors.Panic(errCorrupted)
			}
			// The part of the distance beyond the history encodes both the
			// word index (low part) and the transform index (high part).
			wordIdx := br.dist - (br.dict.HistSize() + 1)
			index := wordIdx % dictSizes[br.cpyLen]
			offset := dictOffsets[br.cpyLen] + index*br.cpyLen
			baseWord := dictLUT[offset : offset+br.cpyLen]
			transformIdx := wordIdx >> uint(dictBitSizes[br.cpyLen])
			if transformIdx >= len(transformLUT) {
				errors.Panic(errCorrupted)
			}
			cnt := transformWord(br.wordBuf[:], baseWord, transformIdx)
			br.word = br.wordBuf[:cnt]
		}

		// Emit as much of the word as fits; br.word keeps the remainder.
		buf := br.dict.WriteSlice()
		cnt := copy(buf, br.word)
		br.word = br.word[cnt:]
		br.blkLen -= cnt
		br.dict.WriteMark(cnt)

		if len(br.word) > 0 {
			br.toRead = br.dict.ReadFlush()
			br.step = (*Reader).readCommands
			br.stepState = stateStaticDict // Need to continue work here
			return
		}
		goto finishCommand
	}

finishCommand:
	// Finish off this command and check if we need to loop again.
	if br.blkLen < 0 {
		// The commands produced more data than the meta-block announced.
		errors.Panic(errCorrupted)
	}
	if br.blkLen > 0 {
		goto startCommand // More commands in this block
	}

	// Done with this block.
	br.toRead = br.dict.ReadFlush()
	br.step = (*Reader).readBlockHeader
	br.stepState = stateInit // Next call to readCommands must start here
}
570
571// readContextMap reads the context map according to RFC section 7.3.
572func (br *Reader) readContextMap(cm []uint8, numTrees uint) {
573	// TODO(dsnet): Test the following edge cases:
574	// * Test with largest and smallest MAXRLE sizes
575	// * Test with with very large MAXRLE value
576	// * Test inverseMoveToFront
577
578	maxRLE := br.rd.ReadSymbol(&decMaxRLE)
579	br.rd.ReadPrefixCode(&br.rd.prefix, maxRLE+numTrees)
580	for i := 0; i < len(cm); {
581		sym := br.rd.ReadSymbol(&br.rd.prefix)
582		if sym == 0 || sym > maxRLE {
583			// Single non-zero value.
584			if sym > 0 {
585				sym -= maxRLE
586			}
587			cm[i] = uint8(sym)
588			i++
589		} else {
590			// Repeated zeros.
591			n := int(br.rd.ReadOffset(sym-1, maxRLERanges))
592			if i+n > len(cm) {
593				errors.Panic(errCorrupted)
594			}
595			for j := i + n; i < j; i++ {
596				cm[i] = 0
597			}
598		}
599	}
600
601	if invert := br.rd.ReadBits(1) == 1; invert {
602		br.mtf.Decode(cm)
603	}
604}
605
606// readBlockSwitch handles a block switch command according to RFC section 6.
607func (br *Reader) readBlockSwitch(bd *blockDecoder) {
608	symType := br.rd.ReadSymbol(&bd.decType)
609	switch symType {
610	case 0:
611		symType = uint(bd.types[1])
612	case 1:
613		symType = uint(bd.types[0]) + 1
614		if symType >= uint(bd.numTypes) {
615			symType -= uint(bd.numTypes)
616		}
617	default:
618		symType -= 2
619	}
620	bd.types = [2]uint8{uint8(symType), bd.types[0]}
621
622	symLen := br.rd.ReadSymbol(&bd.decLen)
623	bd.typeLen = int(br.rd.ReadOffset(symLen, blkLenRanges))
624}
625