1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package meta
6
7import (
8	"bytes"
9	"io"
10
11	"github.com/dsnet/compress/internal/errors"
12	"github.com/dsnet/compress/internal/prefix"
13)
14
// A Writer is an io.Writer that can write XFLATE's meta encoding.
// The zero value of Writer is valid once Reset is called.
//
// Bytes given to Write are collected in an internal buffer and emitted as
// meta blocks. A block is encoded whenever the buffered data could not absorb
// another byte (see Write). The last block is encoded by Close using FinalMode.
type Writer struct {
	InputOffset  int64 // Total number of bytes issued to Write
	OutputOffset int64 // Total number of bytes written to underlying io.Writer
	NumBlocks    int64 // Number of blocks encoded

	// FinalMode determines which final bits (if any) to set.
	// This must be set prior to a call to Close.
	FinalMode FinalMode

	// Output path: each block is assembled by bw into bb and then handed to wr
	// in a single Write call. wr is set to nil by Close.
	wr io.Writer
	bw prefix.Writer // Temporary bit writer
	bb bytes.Buffer  // Buffer for bw to write into

	// Pending input: buf[:bufCnt] holds bytes not yet encoded. buf0s and buf1s
	// must always match the bit counts of buf[:bufCnt], because they are used
	// to choose the Huffman length of the next block.
	buf0s  int               // Number of 0-bits in buf
	buf1s  int               // Number of 1-bits in buf
	bufCnt int               // Number of bytes in buf
	buf    [MaxRawBytes]byte // Buffer to collect raw bytes to be encoded
	cnts   []int             // Slice of counts (reused to avoid allocations)
	err    error             // Persistent error (errClosed after a clean Close)
}
37
38// NewWriter creates a new Writer writing to the given writer.
39// It is the caller's responsibility to call Close to complete the meta stream.
40func NewWriter(wr io.Writer) *Writer {
41	mw := new(Writer)
42	mw.Reset(wr)
43	return mw
44}
45
46// Reset discards the Writer's state and makes it equivalent to the result
47// of a call to NewWriter, but writes to wr instead.
48//
49// This is used to reduce memory allocations.
50func (mw *Writer) Reset(wr io.Writer) {
51	*mw = Writer{
52		wr:   wr,
53		bw:   mw.bw,
54		bb:   mw.bb,
55		cnts: mw.cnts,
56	}
57	return
58}
59
60// Write writes the encoded form of buf to the underlying io.Writer.
61// The Writer may buffer the input in order to produce larger meta blocks.
62func (mw *Writer) Write(buf []byte) (int, error) {
63	if mw.err != nil {
64		return 0, mw.err
65	}
66
67	var wrCnt int
68	for _, b := range buf {
69		zeros, ones := numBits(b)
70
71		// If possible, avoid flushing to maintain high efficiency.
72		if ensured := mw.bufCnt < EnsureRawBytes; ensured {
73			goto skipEncode
74		}
75		if huffLen, _ := mw.computeHuffLen(mw.buf0s+zeros, mw.buf1s+ones); huffLen > 0 {
76			goto skipEncode
77		}
78
79		mw.err = mw.encodeBlock(FinalNil)
80		if mw.err != nil {
81			break
82		}
83
84	skipEncode:
85		mw.buf0s += zeros
86		mw.buf1s += ones
87		mw.buf[mw.bufCnt] = b
88		mw.bufCnt++
89		wrCnt++
90	}
91
92	mw.InputOffset += int64(wrCnt)
93	return wrCnt, mw.err
94}
95
96// Close ends the meta stream and flushes all buffered data.
97// The desired FinalMode must be set prior to calling Close.
98func (mw *Writer) Close() error {
99	if mw.err == errClosed {
100		return nil
101	}
102	if mw.err != nil {
103		return mw.err
104	}
105
106	err := mw.encodeBlock(mw.FinalMode)
107	if err != nil {
108		mw.err = err
109	} else {
110		mw.err = errClosed
111	}
112	mw.wr = nil // Release reference to underlying Writer
113	return err
114}
115
116// computeHuffLen computes the shortest Huffman length to encode the data and
117// reports whether the data bits should be inverted.
118// If the input data is too large, then 0 is returned.
119func (*Writer) computeHuffLen(zeros, ones int) (huffLen uint, inv bool) {
120	if inv = ones > zeros; inv {
121		zeros, ones = ones, zeros
122	}
123	for huffLen = minHuffLen; huffLen <= maxHuffLen; huffLen++ {
124		maxOnes := 1 << huffLen
125		if maxSyms-maxOnes >= zeros+8 && maxOnes >= ones+8 {
126			return huffLen, inv
127		}
128	}
129	return 0, false
130}
131
// computeCounts computes counts of necessary 0s and 1s to form the data.
// A positive count of +n means to repeat a '1' bit n times,
// while a negative count of -n means to repeat a '0' bit n times.
//
// The encoded bit sequence is a flags byte followed by buf (bit-inverted when
// invert is set). Both are read LSB first. The sequence is then padded with
// 0-bits and 1-bits so the totals become maxSyms-maxOnes zeros and maxOnes
// ones. Counting LSB first, the flags byte holds a leading 0 bit, the final
// bit, the invert bit, and len(buf) in its upper five bits.
//
// The returned slice aliases mw.cnts and is overwritten by the next call.
//
// For example, the run-length form of a bit string is (LSB on left):
//	01101011 11100011  =>  [-1, +2, -1, +1, -1, +5, -3, +2]
func (mw *Writer) computeCounts(buf []byte, maxOnes int, final, invert bool) []int {
	// Stack copy of buf for safe mutations.
	var arr [1 + MaxRawBytes]byte
	copy(arr[1:], buf)
	flags := &arr[0]
	buf = arr[1 : 1+len(buf)]
	if invert {
		for i, b := range buf {
			buf[i] = ^b
		}
	}

	// Set the flags.
	*flags |= byte(0) << 0            // Always start with zero bit
	*flags |= byte(btoi(final)) << 1  // Status bit as final meta block
	*flags |= byte(btoi(invert)) << 2 // Status bit that data is inverted
	*flags |= byte(len(buf)) << 3     // Data size

	// Compute the counts.
	var zeros, ones int
	cnts, pcnt := mw.cnts[:0], 0
	for _, b := range arr[:1+len(buf)] {
		// OR-ing in 1<<8 plants a sentinel bit, so the loop runs exactly 8 times.
		// The inner b shadows the byte; numBits below sees the outer byte.
		for b := int(b) | (1 << 8); b != 1; b >>= 1 { // Data bits (LSB first)
			// Close the current run when this bit disagrees with the run's sign.
			// The first bit processed is flags bit 0, which is always 0, so no
			// empty leading run is appended.
			if (b&1 > 0) != (pcnt > 0) {
				cnts, pcnt = append(cnts, pcnt), 0
			}
			pcnt += (b&1)*2 - 1 // Add +1 or -1
		}
		b0s, b1s := numBits(b)
		zeros, ones = zeros+b0s, ones+b1s
	}
	// Pad up to the required totals. A trailing run of ones is closed first.
	// A trailing run of zeros is instead extended by the zero padding.
	// Both padding amounts are expected to be positive when maxOnes comes from
	// computeHuffLen (which keeps 8 bits of slack on each side).
	if pcnt > 0 {
		cnts, pcnt = append(cnts, pcnt), 0
	}
	pcnt += -1 * (maxSyms - maxOnes - zeros) // Add needed zeros
	if pcnt < 0 {
		cnts, pcnt = append(cnts, pcnt), 0
	}
	pcnt += +1 * (maxOnes - ones) // Add needed ones (includes EOB)
	cnts = append(cnts, pcnt)

	mw.cnts = cnts
	return cnts
}
182
183// encodeBlock encodes a single meta block from mw.buf into the
184// underlying Writer. The values buf0s and buf1s must accurately reflect
185// what is in buf. If successful, it will clear bufCnt, buf0s, and buf1s.
186// It also manages the statistic variables: OutputOffset and NumBlocks.
187func (mw *Writer) encodeBlock(final FinalMode) (err error) {
188	defer errors.Recover(&err)
189
190	mw.bb.Reset()
191	mw.bw.Init(&mw.bb, false)
192
193	buf := mw.buf[:mw.bufCnt]
194	huffLen, inv := mw.computeHuffLen(mw.buf0s, mw.buf1s)
195	if huffLen == 0 {
196		return errorf(errors.Invalid, "block too large to encode")
197	}
198
199	// Encode header.
200	numHCLen := 4 + (8-huffLen)*2 // Based on XFLATE specification
201	magic := magicVals
202	magic |= uint32(btoi(final == FinalStream)) << 0 // Set final DEFLATE block bit
203	magic |= uint32(numHCLen-4) << 13                // numHCLen: 6..18, always even
204	mw.bw.WriteBits(uint(magic), 32)
205	for i := uint(5); i < numHCLen-1; i++ {
206		mw.bw.WriteBits(0, 3) // Empty HCLen code
207	}
208	mw.bw.WriteBits(2, 3) // Final HCLen code
209	mw.bw.WriteBits(0, 1) // First HLit code must be zero
210
211	// Encode data segment.
212	cnts := mw.computeCounts(buf, 1<<huffLen, final != FinalNil, inv)
213	cnts[0]++ // Remove first zero bit, treated as part of the header
214	val, pre := 0, -1
215	for len(cnts) > 0 {
216		if cnts[0] == 0 {
217			cnts = cnts[1:]
218			continue
219		}
220		sym := btoi(cnts[0] > 0) // If zero:  0, if one:  1
221		cur := sym*2 - 1         // If zero: -1, if one: +1
222		cnt := cur * cnts[0]     // Count as positive integer
223
224		switch {
225		case cur < 0 && cnt >= minRepZero: // Use repeated zero code
226			if val = maxRepZero; val > cnt {
227				val = cnt
228			}
229			if ok := mw.bw.TryWriteSymbol(symRepZero, &encHuff); !ok {
230				mw.bw.WriteSymbol(symRepZero, &encHuff)
231			}
232			if ok := mw.bw.TryWriteBits(uint(val-minRepZero), 7); !ok {
233				mw.bw.WriteBits(uint(val-minRepZero), 7)
234			}
235		case pre == cur && cnt >= minRepLast: // Use repeated last code
236			if val = maxRepLast; val > cnt {
237				val = cnt
238			}
239			if ok := mw.bw.TryWriteSymbol(symRepLast, &encHuff); !ok {
240				mw.bw.WriteSymbol(symRepLast, &encHuff)
241			}
242			if ok := mw.bw.TryWriteBits(uint(val-minRepLast), 2); !ok {
243				mw.bw.WriteBits(uint(val-minRepLast), 2)
244			}
245		default: // Use literal value
246			val = 1
247			if ok := mw.bw.TryWriteSymbol(uint(sym), &encHuff); !ok {
248				mw.bw.WriteSymbol(uint(sym), &encHuff)
249			}
250		}
251
252		cnts[0] -= cur * val // Decrement count
253		pre = cur            // Store previous sign
254	}
255
256	// Encode footer (and update header with known padding size).
257	pads := numPads(uint(mw.bw.BitsWritten()) + 1 + huffLen)
258	mw.bw.WriteBits(0, pads)                 // Pad to nearest byte
259	mw.bw.WriteBits(0, 1)                    // Empty HDistTree
260	mw.bw.WriteBits((1<<huffLen)-1, huffLen) // Encode EOB marker
261
262	mw.bw.Flush()                       // Flush all data to the bytes.Buffer
263	mw.bb.Bytes()[0] |= byte(pads) << 3 // Update NumHLit size
264
265	// Write the encoded block.
266	cnt, err := mw.wr.Write(mw.bb.Bytes())
267	mw.OutputOffset += int64(cnt)
268	if err != nil {
269		return err
270	}
271	mw.bufCnt, mw.buf0s, mw.buf1s = 0, 0, 0
272	mw.NumBlocks++
273	return nil
274}
275