1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"errors"
9	"fmt"
10	"io"
11)
12
// seq is a single sequence: a run of literals followed by a match.
//
// Judging from String, matchLen is presumably stored without the zstdMinMatch
// bias, offset values 1-3 denote repeat offsets, and larger values hold the
// real offset plus 3 — NOTE(review): confirm against the encoder.
type seq struct {
	litLen, matchLen, offset uint32

	// Codes are stored here for the encoder
	// so they only have to be looked up once.
	llCode uint8
	mlCode uint8
	ofCode uint8
}
22
23func (s seq) String() string {
24	if s.offset <= 3 {
25		if s.offset == 0 {
26			return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset: INVALID (0)")
27		}
28		return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset, " (repeat)")
29	}
30	return fmt.Sprint("litLen:", s.litLen, ", matchLen:", s.matchLen+zstdMinMatch, ", offset:", s.offset-3, " (new)")
31}
32
// seqCompMode names a compression mode for a sequence symbol table.
// NOTE(review): the four values appear to mirror the Symbol_Compression_Modes
// of RFC 8878 (predefined, RLE, FSE-compressed, repeat) — confirm.
type seqCompMode uint8

// Compression modes, numbered explicitly from 0 to 3.
const (
	compModePredefined seqCompMode = 0
	compModeRLE        seqCompMode = 1
	compModeFSE        seqCompMode = 2
	compModeRepeat     seqCompMode = 3
)
41
42type sequenceDec struct {
43	// decoder keeps track of the current state and updates it from the bitstream.
44	fse    *fseDecoder
45	state  fseState
46	repeat bool
47}
48
49// init the state of the decoder with input from stream.
50func (s *sequenceDec) init(br *bitReader) error {
51	if s.fse == nil {
52		return errors.New("sequence decoder not defined")
53	}
54	s.state.init(br, s.fse.actualTableLog, s.fse.dt[:1<<s.fse.actualTableLog])
55	return nil
56}
57
58// sequenceDecs contains all 3 sequence decoders and their state.
59type sequenceDecs struct {
60	litLengths   sequenceDec
61	offsets      sequenceDec
62	matchLengths sequenceDec
63	prevOffset   [3]int
64	hist         []byte
65	dict         []byte
66	literals     []byte
67	out          []byte
68	windowSize   int
69	maxBits      uint8
70}
71
72// initialize all 3 decoders from the stream input.
73func (s *sequenceDecs) initialize(br *bitReader, hist *history, literals, out []byte) error {
74	if err := s.litLengths.init(br); err != nil {
75		return errors.New("litLengths:" + err.Error())
76	}
77	if err := s.offsets.init(br); err != nil {
78		return errors.New("offsets:" + err.Error())
79	}
80	if err := s.matchLengths.init(br); err != nil {
81		return errors.New("matchLengths:" + err.Error())
82	}
83	s.literals = literals
84	s.hist = hist.b
85	s.prevOffset = hist.recentOffsets
86	s.maxBits = s.litLengths.fse.maxBits + s.offsets.fse.maxBits + s.matchLengths.fse.maxBits
87	s.windowSize = hist.windowSize
88	s.out = out
89	s.dict = nil
90	if hist.dict != nil {
91		s.dict = hist.dict.content
92	}
93	return nil
94}
95
96// decode sequences from the stream with the provided history.
97func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error {
98	startSize := len(s.out)
99	// Grab full sizes tables, to avoid bounds checks.
100	llTable, mlTable, ofTable := s.litLengths.fse.dt[:maxTablesize], s.matchLengths.fse.dt[:maxTablesize], s.offsets.fse.dt[:maxTablesize]
101	llState, mlState, ofState := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
102
103	for i := seqs - 1; i >= 0; i-- {
104		if br.overread() {
105			printf("reading sequence %d, exceeded available data\n", seqs-i)
106			return io.ErrUnexpectedEOF
107		}
108		var ll, mo, ml int
109		if br.off > 4+((maxOffsetBits+16+16)>>3) {
110			// inlined function:
111			// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)
112
113			// Final will not read from stream.
114			var llB, mlB, moB uint8
115			ll, llB = llState.final()
116			ml, mlB = mlState.final()
117			mo, moB = ofState.final()
118
119			// extra bits are stored in reverse order.
120			br.fillFast()
121			mo += br.getBits(moB)
122			if s.maxBits > 32 {
123				br.fillFast()
124			}
125			ml += br.getBits(mlB)
126			ll += br.getBits(llB)
127
128			if moB > 1 {
129				s.prevOffset[2] = s.prevOffset[1]
130				s.prevOffset[1] = s.prevOffset[0]
131				s.prevOffset[0] = mo
132			} else {
133				// mo = s.adjustOffset(mo, ll, moB)
134				// Inlined for rather big speedup
135				if ll == 0 {
136					// There is an exception though, when current sequence's literals_length = 0.
137					// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
138					// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
139					mo++
140				}
141
142				if mo == 0 {
143					mo = s.prevOffset[0]
144				} else {
145					var temp int
146					if mo == 3 {
147						temp = s.prevOffset[0] - 1
148					} else {
149						temp = s.prevOffset[mo]
150					}
151
152					if temp == 0 {
153						// 0 is not valid; input is corrupted; force offset to 1
154						println("temp was 0")
155						temp = 1
156					}
157
158					if mo != 1 {
159						s.prevOffset[2] = s.prevOffset[1]
160					}
161					s.prevOffset[1] = s.prevOffset[0]
162					s.prevOffset[0] = temp
163					mo = temp
164				}
165			}
166			br.fillFast()
167		} else {
168			ll, mo, ml = s.next(br, llState, mlState, ofState)
169			br.fill()
170		}
171
172		if debugSequences {
173			println("Seq", seqs-i-1, "Litlen:", ll, "mo:", mo, "(abs) ml:", ml)
174		}
175
176		if ll > len(s.literals) {
177			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, len(s.literals))
178		}
179		size := ll + ml + len(s.out)
180		if size-startSize > maxBlockSize {
181			return fmt.Errorf("output (%d) bigger than max block size", size)
182		}
183		if size > cap(s.out) {
184			// Not enough size, will be extremely rarely triggered,
185			// but could be if destination slice is too small for sync operations.
186			// We add maxBlockSize to the capacity.
187			s.out = append(s.out, make([]byte, maxBlockSize)...)
188			s.out = s.out[:len(s.out)-maxBlockSize]
189		}
190		if ml > maxMatchLen {
191			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
192		}
193
194		// Add literals
195		s.out = append(s.out, s.literals[:ll]...)
196		s.literals = s.literals[ll:]
197		out := s.out
198
199		if mo == 0 && ml > 0 {
200			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
201		}
202
203		if mo > len(s.out)+len(hist) || mo > s.windowSize {
204			if len(s.dict) == 0 {
205				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist))
206			}
207
208			// we may be in dictionary.
209			dictO := len(s.dict) - (mo - (len(s.out) + len(hist)))
210			if dictO < 0 || dictO >= len(s.dict) {
211				return fmt.Errorf("match offset (%d) bigger than current history (%d)", mo, len(s.out)+len(hist))
212			}
213			end := dictO + ml
214			if end > len(s.dict) {
215				out = append(out, s.dict[dictO:]...)
216				mo -= len(s.dict) - dictO
217				ml -= len(s.dict) - dictO
218			} else {
219				out = append(out, s.dict[dictO:end]...)
220				mo = 0
221				ml = 0
222			}
223		}
224
225		// Copy from history.
226		// TODO: Blocks without history could be made to ignore this completely.
227		if v := mo - len(s.out); v > 0 {
228			// v is the start position in history from end.
229			start := len(s.hist) - v
230			if ml > v {
231				// Some goes into current block.
232				// Copy remainder of history
233				out = append(out, s.hist[start:]...)
234				mo -= v
235				ml -= v
236			} else {
237				out = append(out, s.hist[start:start+ml]...)
238				ml = 0
239			}
240		}
241		// We must be in current buffer now
242		if ml > 0 {
243			start := len(s.out) - mo
244			if ml <= len(s.out)-start {
245				// No overlap
246				out = append(out, s.out[start:start+ml]...)
247			} else {
248				// Overlapping copy
249				// Extend destination slice and copy one byte at the time.
250				out = out[:len(out)+ml]
251				src := out[start : start+ml]
252				// Destination is the space we just added.
253				dst := out[len(out)-ml:]
254				dst = dst[:len(src)]
255				for i := range src {
256					dst[i] = src[i]
257				}
258			}
259		}
260		s.out = out
261		if i == 0 {
262			// This is the last sequence, so we shouldn't update state.
263			break
264		}
265
266		// Manually inlined, ~ 5-20% faster
267		// Update all 3 states at once. Approx 20% faster.
268		nBits := llState.nbBits() + mlState.nbBits() + ofState.nbBits()
269		if nBits == 0 {
270			llState = llTable[llState.newState()&maxTableMask]
271			mlState = mlTable[mlState.newState()&maxTableMask]
272			ofState = ofTable[ofState.newState()&maxTableMask]
273		} else {
274			bits := br.getBitsFast(nBits)
275			lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
276			llState = llTable[(llState.newState()+lowBits)&maxTableMask]
277
278			lowBits = uint16(bits >> (ofState.nbBits() & 31))
279			lowBits &= bitMask[mlState.nbBits()&15]
280			mlState = mlTable[(mlState.newState()+lowBits)&maxTableMask]
281
282			lowBits = uint16(bits) & bitMask[ofState.nbBits()&15]
283			ofState = ofTable[(ofState.newState()+lowBits)&maxTableMask]
284		}
285	}
286
287	// Add final literals
288	s.out = append(s.out, s.literals...)
289	return nil
290}
291
292// update states, at least 27 bits must be available.
293func (s *sequenceDecs) update(br *bitReader) {
294	// Max 8 bits
295	s.litLengths.state.next(br)
296	// Max 9 bits
297	s.matchLengths.state.next(br)
298	// Max 8 bits
299	s.offsets.state.next(br)
300}
301
// bitMask[n] has the n lowest bits set, i.e. (1<<n)-1, for n in [0, 16).
// Spelled out as a literal so the table is ready without an init function.
var bitMask = [16]uint16{
	0x0000, 0x0001, 0x0003, 0x0007,
	0x000f, 0x001f, 0x003f, 0x007f,
	0x00ff, 0x01ff, 0x03ff, 0x07ff,
	0x0fff, 0x1fff, 0x3fff, 0x7fff,
}
309
310// update states, at least 27 bits must be available.
311func (s *sequenceDecs) updateAlt(br *bitReader) {
312	// Update all 3 states at once. Approx 20% faster.
313	a, b, c := s.litLengths.state.state, s.matchLengths.state.state, s.offsets.state.state
314
315	nBits := a.nbBits() + b.nbBits() + c.nbBits()
316	if nBits == 0 {
317		s.litLengths.state.state = s.litLengths.state.dt[a.newState()]
318		s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()]
319		s.offsets.state.state = s.offsets.state.dt[c.newState()]
320		return
321	}
322	bits := br.getBitsFast(nBits)
323	lowBits := uint16(bits >> ((c.nbBits() + b.nbBits()) & 31))
324	s.litLengths.state.state = s.litLengths.state.dt[a.newState()+lowBits]
325
326	lowBits = uint16(bits >> (c.nbBits() & 31))
327	lowBits &= bitMask[b.nbBits()&15]
328	s.matchLengths.state.state = s.matchLengths.state.dt[b.newState()+lowBits]
329
330	lowBits = uint16(bits) & bitMask[c.nbBits()&15]
331	s.offsets.state.state = s.offsets.state.dt[c.newState()+lowBits]
332}
333
334// nextFast will return new states when there are at least 4 unused bytes left on the stream when done.
335func (s *sequenceDecs) nextFast(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
336	// Final will not read from stream.
337	ll, llB := llState.final()
338	ml, mlB := mlState.final()
339	mo, moB := ofState.final()
340
341	// extra bits are stored in reverse order.
342	br.fillFast()
343	mo += br.getBits(moB)
344	if s.maxBits > 32 {
345		br.fillFast()
346	}
347	ml += br.getBits(mlB)
348	ll += br.getBits(llB)
349
350	if moB > 1 {
351		s.prevOffset[2] = s.prevOffset[1]
352		s.prevOffset[1] = s.prevOffset[0]
353		s.prevOffset[0] = mo
354		return
355	}
356	// mo = s.adjustOffset(mo, ll, moB)
357	// Inlined for rather big speedup
358	if ll == 0 {
359		// There is an exception though, when current sequence's literals_length = 0.
360		// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
361		// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
362		mo++
363	}
364
365	if mo == 0 {
366		mo = s.prevOffset[0]
367		return
368	}
369	var temp int
370	if mo == 3 {
371		temp = s.prevOffset[0] - 1
372	} else {
373		temp = s.prevOffset[mo]
374	}
375
376	if temp == 0 {
377		// 0 is not valid; input is corrupted; force offset to 1
378		println("temp was 0")
379		temp = 1
380	}
381
382	if mo != 1 {
383		s.prevOffset[2] = s.prevOffset[1]
384	}
385	s.prevOffset[1] = s.prevOffset[0]
386	s.prevOffset[0] = temp
387	mo = temp
388	return
389}
390
391func (s *sequenceDecs) next(br *bitReader, llState, mlState, ofState decSymbol) (ll, mo, ml int) {
392	// Final will not read from stream.
393	ll, llB := llState.final()
394	ml, mlB := mlState.final()
395	mo, moB := ofState.final()
396
397	// extra bits are stored in reverse order.
398	br.fill()
399	if s.maxBits <= 32 {
400		mo += br.getBits(moB)
401		ml += br.getBits(mlB)
402		ll += br.getBits(llB)
403	} else {
404		mo += br.getBits(moB)
405		br.fill()
406		// matchlength+literal length, max 32 bits
407		ml += br.getBits(mlB)
408		ll += br.getBits(llB)
409
410	}
411	mo = s.adjustOffset(mo, ll, moB)
412	return
413}
414
415func (s *sequenceDecs) adjustOffset(offset, litLen int, offsetB uint8) int {
416	if offsetB > 1 {
417		s.prevOffset[2] = s.prevOffset[1]
418		s.prevOffset[1] = s.prevOffset[0]
419		s.prevOffset[0] = offset
420		return offset
421	}
422
423	if litLen == 0 {
424		// There is an exception though, when current sequence's literals_length = 0.
425		// In this case, repeated offsets are shifted by one, so an offset_value of 1 means Repeated_Offset2,
426		// an offset_value of 2 means Repeated_Offset3, and an offset_value of 3 means Repeated_Offset1 - 1_byte.
427		offset++
428	}
429
430	if offset == 0 {
431		return s.prevOffset[0]
432	}
433	var temp int
434	if offset == 3 {
435		temp = s.prevOffset[0] - 1
436	} else {
437		temp = s.prevOffset[offset]
438	}
439
440	if temp == 0 {
441		// 0 is not valid; input is corrupted; force offset to 1
442		println("temp was 0")
443		temp = 1
444	}
445
446	if offset != 1 {
447		s.prevOffset[2] = s.prevOffset[1]
448	}
449	s.prevOffset[1] = s.prevOffset[0]
450	s.prevOffset[0] = temp
451	return temp
452}
453
454// mergeHistory will merge history.
455func (s *sequenceDecs) mergeHistory(hist *sequenceDecs) (*sequenceDecs, error) {
456	for i := uint(0); i < 3; i++ {
457		var sNew, sHist *sequenceDec
458		switch i {
459		default:
460			// same as "case 0":
461			sNew = &s.litLengths
462			sHist = &hist.litLengths
463		case 1:
464			sNew = &s.offsets
465			sHist = &hist.offsets
466		case 2:
467			sNew = &s.matchLengths
468			sHist = &hist.matchLengths
469		}
470		if sNew.repeat {
471			if sHist.fse == nil {
472				return nil, fmt.Errorf("sequence stream %d, repeat requested, but no history", i)
473			}
474			continue
475		}
476		if sNew.fse == nil {
477			return nil, fmt.Errorf("sequence stream %d, no fse found", i)
478		}
479		if sHist.fse != nil && !sHist.fse.preDefined {
480			fseDecoderPool.Put(sHist.fse)
481		}
482		sHist.fse = sNew.fse
483	}
484	return hist, nil
485}
486