1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"fmt"
9	"math"
10	"sync"
11)
12
13var (
14	// fsePredef are the predefined fse tables as defined here:
15	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#default-distributions
16	// These values are already transformed.
17	fsePredef [3]fseDecoder
18
19	// fsePredefEnc are the predefined encoder based on fse tables as defined here:
20	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#default-distributions
21	// These values are already transformed.
22	fsePredefEnc [3]fseEncoder
23
24	// symbolTableX contain the transformations needed for each type as defined in
25	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
26	symbolTableX [3][]baseOffset
27
28	// maxTableSymbol is the biggest supported symbol for each table type
29	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
30	maxTableSymbol = [3]uint8{tableLiteralLengths: maxLiteralLengthSymbol, tableOffsets: maxOffsetLengthSymbol, tableMatchLengths: maxMatchLengthSymbol}
31
32	// bitTables is the bits table for each table.
33	bitTables = [3][]byte{tableLiteralLengths: llBitsTable[:], tableOffsets: nil, tableMatchLengths: mlBitsTable[:]}
34)
35
// tableIndex selects one of the three sequence tables.
type tableIndex uint8

// Indexes for fsePredef and symbolTableX (and the other per-table arrays).
const (
	tableLiteralLengths tableIndex = iota // 0
	tableOffsets                          // 1
	tableMatchLengths                     // 2
)

// Largest symbol value accepted for each table type.
const (
	maxLiteralLengthSymbol = 35
	maxOffsetLengthSymbol  = 30
	maxMatchLengthSymbol   = 52
)
48
// baseOffset is used for calculating transformations.
// An entry describes one symbol code: it covers the 1<<addBits consecutive
// values starting at baseLine (addBits is presumably the number of additional
// bits read to pick a value inside that range — confirm against the decoder).
type baseOffset struct {
	baseLine uint32
	addBits  uint8
}

// fillBase precalculates consecutive base offsets for the given bit
// distribution.
//
// dst[i] receives addBits bits[i] and a baseLine equal to base plus the sizes
// (1<<bits[j]) of all previous entries, so the entries form contiguous,
// non-overlapping ranges starting at base.
//
// It panics if len(bits) != len(dst), if any bit count is 32 or more (the
// range size would not fit in a uint32), or if a baseline would exceed
// math.MaxInt32. It is only used to build static tables, so these conditions
// indicate a programming error rather than bad input.
func fillBase(dst []baseOffset, base uint32, bits ...uint8) {
	if len(bits) != len(dst) {
		panic(fmt.Sprintf("len(dst) (%d) != len(bits) (%d)", len(dst), len(bits)))
	}
	for i, bit := range bits {
		if base > math.MaxInt32 {
			panic("invalid decoding table, base overflows int32")
		}
		// In uint32 arithmetic 1<<bit is 0 for bit >= 32, which would silently
		// produce an empty range and overlapping baselines.
		if bit >= 32 {
			panic(fmt.Sprintf("invalid decoding table, bit count %d too large", bit))
		}

		dst[i] = baseOffset{
			baseLine: base,
			addBits:  bit,
		}
		base += 1 << bit
	}
}
72
73var predef sync.Once
74
75func initPredefined() {
76	predef.Do(func() {
77		// Literals length codes
78		tmp := make([]baseOffset, 36)
79		for i := range tmp[:16] {
80			tmp[i] = baseOffset{
81				baseLine: uint32(i),
82				addBits:  0,
83			}
84		}
85		fillBase(tmp[16:], 16, 1, 1, 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
86		symbolTableX[tableLiteralLengths] = tmp
87
88		// Match length codes
89		tmp = make([]baseOffset, 53)
90		for i := range tmp[:32] {
91			tmp[i] = baseOffset{
92				// The transformation adds the 3 length.
93				baseLine: uint32(i) + 3,
94				addBits:  0,
95			}
96		}
97		fillBase(tmp[32:], 35, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
98		symbolTableX[tableMatchLengths] = tmp
99
100		// Offset codes
101		tmp = make([]baseOffset, maxOffsetBits+1)
102		tmp[1] = baseOffset{
103			baseLine: 1,
104			addBits:  1,
105		}
106		fillBase(tmp[2:], 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30)
107		symbolTableX[tableOffsets] = tmp
108
109		// Fill predefined tables and transform them.
110		// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#default-distributions
111		for i := range fsePredef[:] {
112			f := &fsePredef[i]
113			switch tableIndex(i) {
114			case tableLiteralLengths:
115				// https://github.com/facebook/zstd/blob/ededcfca57366461021c922720878c81a5854a0a/lib/decompress/zstd_decompress_block.c#L243
116				f.actualTableLog = 6
117				copy(f.norm[:], []int16{4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1,
118					2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
119					-1, -1, -1, -1})
120				f.symbolLen = 36
121			case tableOffsets:
122				// https://github.com/facebook/zstd/blob/ededcfca57366461021c922720878c81a5854a0a/lib/decompress/zstd_decompress_block.c#L281
123				f.actualTableLog = 5
124				copy(f.norm[:], []int16{
125					1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
126					1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1})
127				f.symbolLen = 29
128			case tableMatchLengths:
129				//https://github.com/facebook/zstd/blob/ededcfca57366461021c922720878c81a5854a0a/lib/decompress/zstd_decompress_block.c#L304
130				f.actualTableLog = 6
131				copy(f.norm[:], []int16{
132					1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
133					1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
134					1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1,
135					-1, -1, -1, -1, -1})
136				f.symbolLen = 53
137			}
138			if err := f.buildDtable(); err != nil {
139				panic(fmt.Errorf("building table %v: %v", tableIndex(i), err))
140			}
141			if err := f.transform(symbolTableX[i]); err != nil {
142				panic(fmt.Errorf("building table %v: %v", tableIndex(i), err))
143			}
144			f.preDefined = true
145
146			// Create encoder as well
147			enc := &fsePredefEnc[i]
148			copy(enc.norm[:], f.norm[:])
149			enc.symbolLen = f.symbolLen
150			enc.actualTableLog = f.actualTableLog
151			if err := enc.buildCTable(); err != nil {
152				panic(fmt.Errorf("building encoding table %v: %v", tableIndex(i), err))
153			}
154			enc.setBits(bitTables[i])
155			enc.preDefined = true
156		}
157	})
158}
159