1package zstd
2
3import (
4	"bytes"
5	"encoding/binary"
6	"errors"
7	"fmt"
8	"io"
9
10	"github.com/klauspost/compress/huff0"
11)
12
13type dict struct {
14	id uint32
15
16	litDec              *huff0.Scratch
17	llDec, ofDec, mlDec sequenceDec
18	offsets             [3]int
19	content             []byte
20}
21
// dictMagic is the four-byte prefix loadDict requires at the start of a
// serialized dictionary.
var dictMagic = [...]byte{0x37, 0xa4, 0x30, 0xec}
23
24// Load a dictionary as described in
25// https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
26func loadDict(b []byte) (*dict, error) {
27	// Check static field size.
28	if len(b) <= 8+(3*4) {
29		return nil, io.ErrUnexpectedEOF
30	}
31	d := dict{
32		llDec: sequenceDec{fse: &fseDecoder{}},
33		ofDec: sequenceDec{fse: &fseDecoder{}},
34		mlDec: sequenceDec{fse: &fseDecoder{}},
35	}
36	if !bytes.Equal(b[:4], dictMagic[:]) {
37		return nil, ErrMagicMismatch
38	}
39	d.id = binary.LittleEndian.Uint32(b[4:8])
40	if d.id == 0 {
41		return nil, errors.New("dictionaries cannot have ID 0")
42	}
43
44	// Read literal table
45	var err error
46	d.litDec, b, err = huff0.ReadTable(b[8:], nil)
47	if err != nil {
48		return nil, err
49	}
50
51	br := byteReader{
52		b:   b,
53		off: 0,
54	}
55	readDec := func(i tableIndex, dec *fseDecoder) error {
56		if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
57			return err
58		}
59		if br.overread() {
60			return io.ErrUnexpectedEOF
61		}
62		err = dec.transform(symbolTableX[i])
63		if err != nil {
64			println("Transform table error:", err)
65			return err
66		}
67		if debug {
68			println("Read table ok", "symbolLen:", dec.symbolLen)
69		}
70		// Set decoders as predefined so they aren't reused.
71		dec.preDefined = true
72		return nil
73	}
74
75	if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
76		return nil, err
77	}
78	if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
79		return nil, err
80	}
81	if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
82		return nil, err
83	}
84	if br.remain() < 12 {
85		return nil, io.ErrUnexpectedEOF
86	}
87
88	d.offsets[0] = int(br.Uint32())
89	br.advance(4)
90	d.offsets[1] = int(br.Uint32())
91	br.advance(4)
92	d.offsets[2] = int(br.Uint32())
93	br.advance(4)
94	if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
95		return nil, errors.New("invalid offset in dictionary")
96	}
97	d.content = make([]byte, br.remain())
98	copy(d.content, br.unread())
99	if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
100		return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
101	}
102
103	return &d, nil
104}
105