1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"errors"
9	"fmt"
10	"io"
11)
12
// seq describes a single sequence: a run of literals followed by a match.
type seq struct {
	// Raw field values. String() reports matchLen+zstdMinMatch and treats
	// offset values 1-3 as repeat codes (larger values print as offset-3).
	litLen, matchLen, offset uint32

	// Symbol codes cached for the encoder so each one is only looked up once.
	llCode uint8
	mlCode uint8
	ofCode uint8
}
22
23func (s seq) String() string {
24	if s.offset <= 3 {
25		if s.offset == 0 {
26			return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset: INVALID (0)")
27		}
28		return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset, " (repeat)")
29	}
30	return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset-3, " (new)")
31}
32
// seqCompMode selects how one sequence table (literal lengths, offsets or
// match lengths) is provided. The values 0-3 follow the order Predefined,
// RLE, FSE, Repeat. NOTE(review): presumably matching the 2-bit mode fields
// of the Zstandard format — confirm against the reader.
type seqCompMode uint8

const (
	compModePredefined seqCompMode = 0
	compModeRLE        seqCompMode = 1
	compModeFSE        seqCompMode = 2
	compModeRepeat     seqCompMode = 3
)
41
42type sequenceDec struct {
43	// decoder keeps track of the current state and updates it from the bitstream.
44	fse    *fseDecoder
45	state  fseState
46	repeat bool
47}
48
49// init the state of the decoder with input from stream.
50func (s *sequenceDec) init(br *bitReader) error {
51	if s.fse == nil {
52		return errors.New("sequence decoder not defined")
53	}
54	s.state.init(br, s.fse.actualTableLog, s.fse.dt[:1<<s.fse.actualTableLog])
55	return nil
56}
57
58// sequenceDecs contains all 3 sequence decoders and their state.
59type sequenceDecs struct {
60	litLengths   sequenceDec
61	offsets      sequenceDec
62	matchLengths sequenceDec
63	prevOffset   [3]int
64	hist         []byte
65	dict         []byte
66	literals     []byte
67	out          []byte
68	windowSize   int
69	maxBits      uint8
70}
71
72// initialize all 3 decoders from the stream input.
73func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out []byte) error {
74	if err := s.litLengths.init(br); err != nil {
75		return errors.New("litLengths:" + err.Error())
76	}
77	if err := s.offsets.init(br); err != nil {
78		return errors.New("offsets:" + err.Error())
79	}
80	if err := s.matchLengths.init(br); err != nil {
81		return errors.New("matchLengths:" + err.Error())
82	}
83	s.literals = literals
84	s.hist = hist.b
85	s.prevOffset = hist.recentOffsets
86	s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits
87	s.windowSize = hist.windowSize
88	s.out = out
89	s.dict = nil
90	if hist.dict != nil {
91		s.dict = hist.dict.content
92	}
93	return nil
94}
95
// decode sequences from the stream with the provided history.
//
// It decodes seqs sequences from br. For each sequence it appends ll literals
// (taken from the front of s.literals) and then ml bytes of match output to
// s.out. Match bytes may come from s.out itself, from the history s.hist that
// precedes it, or, when a dictionary is loaded, from s.dict. The dictionary is
// indexed as if it lay directly before the history. After the last sequence,
// any literals still left in s.literals are appended to s.out.
//
// hist is only used to bound-check match offsets; the bytes themselves are
// copied from s.hist. NOTE(review): this assumes hist and s.hist are the same
// slice (initialize sets s.hist from the history) — confirm at call sites.
//
// It returns io.ErrUnexpectedEOF when br has been over-read. It returns a
// descriptive error when a sequence:
//   - wants more literals than remain;
//   - would grow the output produced by this call past maxBlockSize;
//   - has a match length above maxMatchLen;
//   - has a zero offset with a non-zero match length;
//   - has an offset that reaches beyond the available history/dictionary.
//
// The FSE states are advanced in local copies and are not written back to s.
func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error {
	// Output length on entry; the maxBlockSize limit applies to bytes added after this point.
	startSize := len(s.out)
	// Grab full sizes tables, to avoid bounds checks.
	llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
	llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state

	// Count down so that i == 0 identifies the last sequence, after which the
	// states must not be advanced (see the break near the bottom of the loop).
	for i := seqs - 1; i >= 0; i-- {
		if br.overread() {
			printf("reading sequence %d, exceeded available data\n", seqs-i)
			return io.ErrUnexpectedEOF
		}
		var ll, mo, ml int
		// Fast path using fillFast. NOTE(review): presumably br.off counts input
		// bytes not yet loaded, and the threshold covers the worst-case extra
		// bits of one sequence plus 4 bytes of slack — confirm in bitReader.
		if br.off > 4+((maxOffsetBits+16+16)>>3) {
			// inlined function:
			// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)

			// Final will not read from stream.
			var llB, mlB, moB uint8
			ll, llB = llState.final()
			ml, mlB = mlState.final()
			mo, moB = ofState.final()

			// extra bits are stored in reverse order.
			br.fillFast()
			mo += br.getBits(moB)
			if s.maxBits > 32 {
				br.fillFast()
			}
			ml += br.getBits(mlB)
			ll += br.getBits(llB)

			if moB > 1 {
				// New offset: push it onto the repeat-offset history.
				s.prevOffset[2] = s.prevOffset[1]
				s.prevOffset[1] = s.prevOffset[0]
				s.prevOffset[0] = mo
			} else {
				// mo = s.adjustOffset(mo, ll, moB)
				// Inlined for rather big speedup
				if ll == 0 {
					// There is an exception though, when current sequence's literals_length = 0.
					// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
					// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
					mo++
				}

				if mo == 0 {
					mo = s.prevOffset[0]
				} else {
					var temp int
					if mo == 3 {
						temp = s.prevOffset[0] - 1
					} else {
						temp = s.prevOffset[mo]
					}

					if temp == 0 {
						// 0 is not valid; input is corrupted; force offset to 1
						// NOTE(review): this prints to stderr unconditionally on corrupted input.
						println("temp was 0")
						temp = 1
					}

					if mo != 1 {
						s.prevOffset[2] = s.prevOffset[1]
					}
					s.prevOffset[1] = s.prevOffset[0]
					s.prevOffset[0] = temp
					mo = temp
				}
			}
			// Refill so the state update at the bottom of the loop can use getBitsFast.
			br.fillFast()
		} else {
			// Slow path near the end of the input: uses the bounds-checked fill.
			ll, mo, ml = s.next(br, llState, mlState, ofState)
			br.fill()
		}

		if debugSequences {
			println("Seq", seqs-i-1, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml)
		}

		if ll > len(s.literals) {
			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, len(s.literals))
		}
		size := ll + ml + len(s.out)
		if size-startSize > maxBlockSize {
			// NOTE(review): the message reports the absolute output length,
			// including the startSize bytes present on entry.
			return fmt.Errorf("output (%d) bigger than max block size", size)
		}
		if size > cap(s.out) {
			// Not enough size, which can happen under high volume block streaming conditions
			// but could be if destination slice is too small for sync operations.
			// over-allocating here can create a large amount of GC pressure so we try to keep
			// it as contained as possible
			used := len(s.out) - startSize
			addBytes := 256 + ll + ml + used>>2
			// Clamp to max block size.
			if used+addBytes > maxBlockSize {
				addBytes = maxBlockSize - used
			}
			// Grow capacity only; the length is restored right away.
			s.out = append(s.out, make([]byte, addBytes)...)
			s.out = s.out[:len(s.out)-addBytes]
		}
		if ml > maxMatchLen {
			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
		}

		// Add literals
		s.out = append(s.out, s.literals[:ll]...)
		s.literals = s.literals[ll:]
		// out accumulates the match bytes; s.out still marks where this match began.
		out := s.out

		if mo == 0 && ml > 0 {
			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
		}

		if mo > len(s.out)+len(hist) || mo > s.windowSize {
			if len(s.dict) == 0 {
				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist))
			}

			// we may be in dictionary.
			dictO := len(s.dict) - (mo - (len(s.out) + len(hist)))
			if dictO < 0 || dictO >= len(s.dict) {
				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist))
			}
			end := dictO + ml
			if end > len(s.dict) {
				// Match continues past the dictionary: copy its tail and carry the
				// rest over to the history/output copies below.
				out = append(out, s.dict[dictO:]...)
				mo -= len(s.dict) - dictO
				ml -= len(s.dict) - dictO
			} else {
				out = append(out, s.dict[dictO:end]...)
				mo = 0
				ml = 0
			}
		}

		// Copy from history.
		// TODO: Blocks without history could be made to ignore this completely.
		if v := mo - len(s.out); v > 0 {
			// v is the start position in history from end.
			start := len(s.hist) - v
			if ml > v {
				// Some goes into current block.
				// Copy remainder of history
				out = append(out, s.hist[start:]...)
				mo -= v
				ml -= v
			} else {
				out = append(out, s.hist[start:start+ml]...)
				ml = 0
			}
		}
		// We must be in current buffer now
		if ml > 0 {
			start := len(s.out) - mo
			if ml <= len(s.out)-start {
				// No overlap
				out = append(out, s.out[start:start+ml]...)
			} else {
				// Overlapping copy
				// Extend destination slice and copy one byte at the time.
				// Capacity for this was reserved by the growth step above.
				out = out[:len(out)+ml]
				src := out[start : start+ml]
				// Destination is the space we just added.
				dst := out[len(out)-ml:]
				dst = dst[:len(src)]
				// Forward byte copy so bytes written earlier in this loop are re-read (run-length style).
				for i := range src {
					dst[i] = src[i]
				}
			}
		}
		s.out = out
		if i == 0 {
			// This is the last sequence, so we shouldn't update state.
			break
		}

		// Manually inlined, ~ 5-20% faster
		// Update all 3 states at once. Approx 20% faster.
		// The nBits extra bits are read in one call: the highest bits go to the
		// literal-length state, the middle bits to match length and the lowest
		// to offset. Indices are masked with maxTableMask so the lookups into
		// the full-size tables need no bounds checks.
		nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
		if nBits == 0 {
			llState = llTable[llState.newState()&maxTableMask]
			mlState = mlTable[mlState.newState()&maxTableMask]
			ofState = ofTable[ofState.newState()&maxTableMask]
		} else {
			bits := br.getBitsFast(nBits)
			lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
			llState = llTable[(llState.newState()+lowBits)&maxTableMask]

			lowBits = uint16(bits >> (ofState.nbBits() & 31))
			lowBits &= bitMask[mlState.nbBits()&15]
			mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]

			lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
			ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
		}
	}

	// Add final literals
	s.out = append(s.out, s.literals...)
	return nil
}
298
299// update states, at least 27 bits must be available.
300func (s *sequenceDecs) update(br *bitReader) {
301	// Max 8 bits
302	s.litLengths.state.next(br)
303	// Max 9 bits
304	s.matchLengths.state.next(br)
305	// Max 8 bits
306	s.offsets.state.next(br)
307}
308
// bitMask[n] has the low n bits set, i.e. (1<<n)-1, for n in 0..15.
var bitMask = [16]uint16{
	0x0000, 0x0001, 0x0003, 0x0007,
	0x000f, 0x001f, 0x003f, 0x007f,
	0x00ff, 0x01ff, 0x03ff, 0x07ff,
	0x0fff, 0x1fff, 0x3fff, 0x7fff,
}
316
317// update states, at least 27 bits must be available.
318func (s *sequenceDecs) updateAlt(br *bitReader) {
319	// Update all 3 states at once. Approx 20% faster.
320	a, b, c := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
321
322	nBits := a.nbBits() + b.nbBits() + c.nbBits()
323	if nBits == 0 {
324		s.litLengths.state.state = s.litLengths.state.dt[a.newState()]
325		s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()]
326		s.offsets.state.state = s.offsets.state.dt[c.newState()]
327		return
328	}
329	bits := br.getBitsFast(nBits)
330	lowBits := uint16(bits >> ((c.nbBits() + b.nbBits()) & 31))
331	s.litLengths.state.state = s.litLengths.state.dt[a.newState()+lowBits]
332
333	lowBits = uint16(bits >> (c.nbBits() & 31))
334	lowBits &= bitMask[b.nbBits()&15]
335	s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()+lowBits]
336
337	lowBits = uint16(bits) & bitMask[c.nbBits()&15]
338	s.offsets.state.state = s.offsets.state.dt[c.newState()+lowBits]
339}
340
341// nextFast will return new states when there are at least 4 unused bytes left on the stream when done.
342func (s *sequenceDecs) nextFast(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
343	// Final will not read from stream.
344	ll, llB := llState.final()
345	ml, mlB := mlState.final()
346	mo, moB := ofState.final()
347
348	// extra bits are stored in reverse order.
349	br.fillFast()
350	mo += br.getBits(moB)
351	if s.maxBits > 32 {
352		br.fillFast()
353	}
354	ml += br.getBits(mlB)
355	ll += br.getBits(llB)
356
357	if moB > 1 {
358		s.prevOffset[2] = s.prevOffset[1]
359		s.prevOffset[1] = s.prevOffset[0]
360		s.prevOffset[0] = mo
361		return
362	}
363	// mo = s.adjustOffset(mo, ll, moB)
364	// Inlined for rather big speedup
365	if ll == 0 {
366		// There is an exception though, when current sequence's literals_length = 0.
367		// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
368		// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
369		mo++
370	}
371
372	if mo == 0 {
373		mo = s.prevOffset[0]
374		return
375	}
376	var temp int
377	if mo == 3 {
378		temp = s.prevOffset[0] - 1
379	} else {
380		temp = s.prevOffset[mo]
381	}
382
383	if temp == 0 {
384		// 0 is not valid; input is corrupted; force offset to 1
385		println("temp was 0")
386		temp = 1
387	}
388
389	if mo != 1 {
390		s.prevOffset[2] = s.prevOffset[1]
391	}
392	s.prevOffset[1] = s.prevOffset[0]
393	s.prevOffset[0] = temp
394	mo = temp
395	return
396}
397
398func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
399	// Final will not read from stream.
400	ll, llB := llState.final()
401	ml, mlB := mlState.final()
402	mo, moB := ofState.final()
403
404	// extra bits are stored in reverse order.
405	br.fill()
406	if s.maxBits <= 32 {
407		mo += br.getBits(moB)
408		ml += br.getBits(mlB)
409		ll += br.getBits(llB)
410	} else {
411		mo += br.getBits(moB)
412		br.fill()
413		// matchlength+literal length, max 32 bits
414		ml += br.getBits(mlB)
415		ll += br.getBits(llB)
416
417	}
418	mo = s.adjustOffset(mo, ll, moB)
419	return
420}
421
422func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int {
423	if offsetB > 1 {
424		s.prevOffset[2] = s.prevOffset[1]
425		s.prevOffset[1] = s.prevOffset[0]
426		s.prevOffset[0] = offset
427		return offset
428	}
429
430	if litLen == 0 {
431		// There is an exception though, when current sequence's literals_length = 0.
432		// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
433		// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
434		offset++
435	}
436
437	if offset == 0 {
438		return s.prevOffset[0]
439	}
440	var temp int
441	if offset == 3 {
442		temp = s.prevOffset[0] - 1
443	} else {
444		temp = s.prevOffset[offset]
445	}
446
447	if temp == 0 {
448		// 0 is not valid; input is corrupted; force offset to 1
449		println("temp was 0")
450		temp = 1
451	}
452
453	if offset != 1 {
454		s.prevOffset[2] = s.prevOffset[1]
455	}
456	s.prevOffset[1] = s.prevOffset[0]
457	s.prevOffset[0] = temp
458	return temp
459}
460
461// mergeHistory will merge history.
462func (s *sequenceDecs) mergeHistory(hist *sequenceDecs) (*sequenceDecs, error) {
463	for i := uint(0); i < 3; i++ {
464		var sNew, sHist *sequenceDec
465		switch i {
466		default:
467			// same as "case 0":
468			sNew = &s.litLengths
469			sHist = &hist.litLengths
470		case 1:
471			sNew = &s.offsets
472			sHist = &hist.offsets
473		case 2:
474			sNew = &s.matchLengths
475			sHist = &hist.matchLengths
476		}
477		if sNew.repeat {
478			if sHist.fse == nil {
479				return nil, fmt.Errorf("sequence stream %d, repeat requested, but no history", i)
480			}
481			continue
482		}
483		if sNew.fse == nil {
484			return nil, fmt.Errorf("sequence stream %d, no fse found", i)
485		}
486		if sHist.fse != nil && !sHist.fse.preDefined {
487			fseDecoderPool.Put(sHist.fse)
488		}
489		sHist.fse = sNew.fse
490	}
491	return hist, nil
492}
493