1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jpeg
6
7import (
8	"io"
9)
10
// Size limits of the Huffman decoder.
const (
	// maxCodeLength is the largest number of bits (inclusive) that a single
	// Huffman code may occupy.
	maxCodeLength = 16
	// maxNCodes is the largest number of codes (inclusive) that one Huffman
	// tree may define.
	maxNCodes = 256
	// lutSize is the base-2 logarithm of the number of entries in the
	// Huffman decoder's look-up table.
	lutSize = 8
)
19
// huffman is a Huffman decoder, specified in section C.
type huffman struct {
	// nCodes is the number of codes in the tree. A value of 0 means the
	// table has not been defined yet.
	nCodes int32
	// lut is the look-up table for the next lutSize bits in the bit-stream.
	// The high 8 bits of the uint16 are the encoded value. The low 8 bits
	// are 1 plus the code length, or 0 if those lutSize bits do not begin
	// with a code that is at most lutSize bits long (the caller must then
	// decode bit by bit using minCodes/maxCodes/valsIndices).
	lut [1 << lutSize]uint16
	// vals are the decoded values, sorted by their encoding.
	vals [maxNCodes]uint8
	// minCodes[i] is the minimum code of length i+1, or -1 if there are no
	// codes of that length.
	minCodes [maxCodeLength]int32
	// maxCodes[i] is the maximum code of length i+1, or -1 if there are no
	// codes of that length.
	maxCodes [maxCodeLength]int32
	// valsIndices[i] is the index into vals of minCodes[i], or -1 if there
	// are no codes of length i+1.
	valsIndices [maxCodeLength]int32
}
40
41// errShortHuffmanData means that an unexpected EOF occurred while decoding
42// Huffman data.
43var errShortHuffmanData = FormatError("short Huffman data")
44
45// ensureNBits reads bytes from the byte buffer to ensure that d.bits.n is at
46// least n. For best performance (avoiding function calls inside hot loops),
47// the caller is the one responsible for first checking that d.bits.n < n.
48func (d *decoder) ensureNBits(n int32) error {
49	for {
50		c, err := d.readByteStuffedByte()
51		if err != nil {
52			if err == io.EOF {
53				return errShortHuffmanData
54			}
55			return err
56		}
57		d.bits.a = d.bits.a<<8 | uint32(c)
58		d.bits.n += 8
59		if d.bits.m == 0 {
60			d.bits.m = 1 << 7
61		} else {
62			d.bits.m <<= 8
63		}
64		if d.bits.n >= n {
65			break
66		}
67	}
68	return nil
69}
70
71// receiveExtend is the composition of RECEIVE and EXTEND, specified in section
72// F.2.2.1.
73func (d *decoder) receiveExtend(t uint8) (int32, error) {
74	if d.bits.n < int32(t) {
75		if err := d.ensureNBits(int32(t)); err != nil {
76			return 0, err
77		}
78	}
79	d.bits.n -= int32(t)
80	d.bits.m >>= t
81	s := int32(1) << t
82	x := int32(d.bits.a>>uint8(d.bits.n)) & (s - 1)
83	if x < s>>1 {
84		x += ((-1) << t) + 1
85	}
86	return x, nil
87}
88
// processDHT processes a Define Huffman Table marker, and initializes a huffman
// struct from its contents. Specified in section B.2.4.2.
//
// n is the number of bytes remaining in the marker segment. A single segment
// may define several tables, so table definitions are parsed until those n
// bytes are consumed. Each definition overwrites d.huff[tc][th].
//
// It returns a FormatError when:
//   - the length is inconsistent;
//   - Tc or Th is out of range;
//   - a table defines zero codes;
//   - a table defines more than maxNCodes codes.
//
// Read errors are returned unchanged.
func (d *decoder) processDHT(n int) error {
	for n > 0 {
		// Each definition starts with a 17-byte header: the Tc/Th byte
		// followed by 16 per-length code counts.
		if n < 17 {
			return FormatError("DHT has wrong length")
		}
		if err := d.readFull(d.tmp[:17]); err != nil {
			return err
		}
		// Tc is the high nibble and Th the low nibble of the first header
		// byte; together they select which entry of d.huff is defined.
		tc := d.tmp[0] >> 4
		if tc > maxTc {
			return FormatError("bad Tc value")
		}
		th := d.tmp[0] & 0x0f
		// The baseline th <= 1 restriction is specified in table B.5.
		if th > maxTh || (d.baseline && th > 1) {
			return FormatError("bad Th value")
		}
		h := &d.huff[tc][th]

		// Read nCodes and h.vals (and derive h.nCodes).
		// nCodes[i] is the number of codes with code length i+1.
		// h.nCodes is the total number of codes.
		h.nCodes = 0
		var nCodes [maxCodeLength]int32
		for i := range nCodes {
			nCodes[i] = int32(d.tmp[i+1])
			h.nCodes += nCodes[i]
		}
		if h.nCodes == 0 {
			return FormatError("Huffman table has zero length")
		}
		if h.nCodes > maxNCodes {
			return FormatError("Huffman table has excessive length")
		}
		// Account for this definition's header and value bytes.
		n -= int(h.nCodes) + 17
		if n < 0 {
			return FormatError("DHT has wrong length")
		}
		if err := d.readFull(h.vals[:h.nCodes]); err != nil {
			return err
		}

		// Derive the look-up table.
		// Only codes of length <= lutSize get entries. Every lutSize-bit
		// pattern that begins with such a code maps to it. Entries left at 0
		// mean "no short code matches here".
		for i := range h.lut {
			h.lut[i] = 0
		}
		var x, code uint32
		for i := uint32(0); i < lutSize; i++ {
			// Canonical Huffman: the first code of the next length is the
			// previous next-code shifted left by one bit.
			code <<= 1
			for j := int32(0); j < nCodes[i]; j++ {
				// The codeLength is 1+i, so shift code by 8-(1+i) to
				// calculate the high bits for every 8-bit sequence
				// whose codeLength's high bits matches code.
				// The high 8 bits of lutValue are the encoded value.
				// The low 8 bits are 1 plus the codeLength.
				base := uint8(code << (7 - i))
				lutValue := uint16(h.vals[x])<<8 | uint16(2+i)
				for k := uint8(0); k < 1<<(7-i); k++ {
					h.lut[base|k] = lutValue
				}
				code++
				x++
			}
		}

		// Derive minCodes, maxCodes, and valsIndices.
		// Codes of each length are consecutive (canonical Huffman), so the
		// range [minCodes[i], maxCodes[i]] plus a starting index into vals
		// describes every code of length i+1.
		// Note: the loop variable n below shadows the outer byte count n;
		// the outer n is not modified inside this loop.
		var c, index int32
		for i, n := range nCodes {
			if n == 0 {
				h.minCodes[i] = -1
				h.maxCodes[i] = -1
				h.valsIndices[i] = -1
			} else {
				h.minCodes[i] = c
				h.maxCodes[i] = c + n - 1
				h.valsIndices[i] = index
				c += n
				index += n
			}
			c <<= 1
		}
	}
	return nil
}
175
// decodeHuffman returns the next Huffman-coded value from the bit-stream,
// decoded according to h.
//
// Fast path: when at least lutSize (8) bits are buffered, it does a single
// h.lut look-up. The slow path reads one bit at a time and compares the
// accumulated code against h.maxCodes for each code length. It is taken when:
//   - the look-up finds no short code; or
//   - the segment has run out of bytes, but some bits are still buffered.
//
// It returns a FormatError if h has no codes, or if no code matches within
// maxCodeLength bits. Other read errors are returned unchanged.
func (d *decoder) decodeHuffman(h *huffman) (uint8, error) {
	if h.nCodes == 0 {
		return 0, FormatError("uninitialized Huffman table")
	}

	if d.bits.n < 8 {
		if err := d.ensureNBits(8); err != nil {
			if err != errMissingFF00 && err != errShortHuffmanData {
				return 0, err
			}
			// There are no more bytes of data in this segment, but we may still
			// be able to read the next symbol out of the previously read bits.
			// First, undo the readByte that the ensureNBits call made.
			if d.bytes.nUnreadable != 0 {
				d.unreadByteStuffedByte()
			}
			goto slowPath
		}
	}
	// Peek at the next lutSize bits without consuming them. A non-zero entry
	// encodes the value (high byte) and 1 plus the code length (low byte);
	// only the code length's worth of bits is then consumed.
	if v := h.lut[(d.bits.a>>uint32(d.bits.n-lutSize))&0xff]; v != 0 {
		n := (v & 0xff) - 1
		d.bits.n -= int32(n)
		d.bits.m >>= n
		return uint8(v >> 8), nil
	}

slowPath:
	// After reading i+1 bits, code holds those bits. Because codes of each
	// length are consecutive, code <= h.maxCodes[i] means it is a valid code
	// of length i+1.
	for i, code := 0, int32(0); i < maxCodeLength; i++ {
		if d.bits.n == 0 {
			if err := d.ensureNBits(1); err != nil {
				return 0, err
			}
		}
		if d.bits.a&d.bits.m != 0 {
			code |= 1
		}
		d.bits.n--
		d.bits.m >>= 1
		if code <= h.maxCodes[i] {
			return h.vals[h.valsIndices[i]+code-h.minCodes[i]], nil
		}
		code <<= 1
	}
	return 0, FormatError("bad Huffman code")
}
223
224func (d *decoder) decodeBit() (bool, error) {
225	if d.bits.n == 0 {
226		if err := d.ensureNBits(1); err != nil {
227			return false, err
228		}
229	}
230	ret := d.bits.a&d.bits.m != 0
231	d.bits.n--
232	d.bits.m >>= 1
233	return ret, nil
234}
235
236func (d *decoder) decodeBits(n int32) (uint32, error) {
237	if d.bits.n < n {
238		if err := d.ensureNBits(n); err != nil {
239			return 0, err
240		}
241	}
242	ret := d.bits.a >> uint32(d.bits.n-n)
243	ret &= (1 << uint32(n)) - 1
244	d.bits.n -= n
245	d.bits.m >>= uint32(n)
246	return ret, nil
247}
248