1package zstd
2
3import (
4	"bytes"
5	"encoding/binary"
6	"errors"
7	"fmt"
8	"io"
9
10	"github.com/klauspost/compress/huff0"
11)
12
13type dict struct {
14	id uint32
15
16	litEnc              *huff0.Scratch
17	llDec, ofDec, mlDec sequenceDec
18	//llEnc, ofEnc, mlEnc []*fseEncoder
19	offsets [3]int
20	content []byte
21}
22
// dictMagic is the four-byte magic number that loadDict requires at the very
// start of a dictionary (0xEC30A437 when read as a little-endian uint32).
var dictMagic = [...]byte{
	0x37, 0xa4, 0x30, 0xec,
}
24
25// ID returns the dictionary id or 0 if d is nil.
26func (d *dict) ID() uint32 {
27	if d == nil {
28		return 0
29	}
30	return d.id
31}
32
33// DictContentSize returns the dictionary content size or 0 if d is nil.
34func (d *dict) DictContentSize() int {
35	if d == nil {
36		return 0
37	}
38	return len(d.content)
39}
40
41// Load a dictionary as described in
42// https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
43func loadDict(b []byte) (*dict, error) {
44	// Check static field size.
45	if len(b) <= 8+(3*4) {
46		return nil, io.ErrUnexpectedEOF
47	}
48	d := dict{
49		llDec: sequenceDec{fse: &fseDecoder{}},
50		ofDec: sequenceDec{fse: &fseDecoder{}},
51		mlDec: sequenceDec{fse: &fseDecoder{}},
52	}
53	if !bytes.Equal(b[:4], dictMagic[:]) {
54		return nil, ErrMagicMismatch
55	}
56	d.id = binary.LittleEndian.Uint32(b[4:8])
57	if d.id == 0 {
58		return nil, errors.New("dictionaries cannot have ID 0")
59	}
60
61	// Read literal table
62	var err error
63	d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
64	if err != nil {
65		return nil, err
66	}
67	d.litEnc.Reuse = huff0.ReusePolicyMust
68
69	br := byteReader{
70		b:   b,
71		off: 0,
72	}
73	readDec := func(i tableIndex, dec *fseDecoder) error {
74		if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
75			return err
76		}
77		if br.overread() {
78			return io.ErrUnexpectedEOF
79		}
80		err = dec.transform(symbolTableX[i])
81		if err != nil {
82			println("Transform table error:", err)
83			return err
84		}
85		if debug {
86			println("Read table ok", "symbolLen:", dec.symbolLen)
87		}
88		// Set decoders as predefined so they aren't reused.
89		dec.preDefined = true
90		return nil
91	}
92
93	if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
94		return nil, err
95	}
96	if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
97		return nil, err
98	}
99	if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
100		return nil, err
101	}
102	if br.remain() < 12 {
103		return nil, io.ErrUnexpectedEOF
104	}
105
106	d.offsets[0] = int(br.Uint32())
107	br.advance(4)
108	d.offsets[1] = int(br.Uint32())
109	br.advance(4)
110	d.offsets[2] = int(br.Uint32())
111	br.advance(4)
112	if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
113		return nil, errors.New("invalid offset in dictionary")
114	}
115	d.content = make([]byte, br.remain())
116	copy(d.content, br.unread())
117	if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
118		return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
119	}
120
121	return &d, nil
122}
123