1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package bzip2
6
7import (
8	"io"
9
10	"github.com/dsnet/compress/internal"
11	"github.com/dsnet/compress/internal/errors"
12	"github.com/dsnet/compress/internal/prefix"
13)
14
15type Writer struct {
16	InputOffset  int64 // Total number of bytes issued to Write
17	OutputOffset int64 // Total number of bytes written to underlying io.Writer
18
19	wr     prefixWriter
20	err    error
21	level  int    // The current compression level
22	wrHdr  bool   // Have we written the stream header?
23	blkCRC uint32 // CRC-32 IEEE of each block
24	endCRC uint32 // Checksum of all blocks using bzip2's custom method
25
26	crc crc
27	rle runLengthEncoding
28	bwt burrowsWheelerTransform
29	mtf moveToFront
30
31	// These fields are allocated with Writer and re-used later.
32	buf         []byte
33	treeSels    []uint8
34	treeSelsMTF []uint8
35	codes2D     [maxNumTrees][maxNumSyms]prefix.PrefixCode
36	codes1D     [maxNumTrees]prefix.PrefixCodes
37	trees1D     [maxNumTrees]prefix.Encoder
38}
39
// WriterConfig carries the optional settings accepted by NewWriter.
type WriterConfig struct {
	// Level is the requested compression level. NewWriter substitutes
	// DefaultCompression when it is zero and rejects values outside of
	// the range BestSpeed..BestCompression.
	Level int

	// The blank field forces callers to use keyed struct literals.
	_ struct{}
}
45
46func NewWriter(w io.Writer, conf *WriterConfig) (*Writer, error) {
47	var lvl int
48	if conf != nil {
49		lvl = conf.Level
50	}
51	if lvl == 0 {
52		lvl = DefaultCompression
53	}
54	if lvl < BestSpeed || lvl > BestCompression {
55		return nil, errorf(errors.Invalid, "compression level: %d", lvl)
56	}
57	zw := new(Writer)
58	zw.level = lvl
59	zw.Reset(w)
60	return zw, nil
61}
62
63func (zw *Writer) Reset(w io.Writer) error {
64	*zw = Writer{
65		wr:    zw.wr,
66		level: zw.level,
67
68		rle: zw.rle,
69		bwt: zw.bwt,
70		mtf: zw.mtf,
71
72		buf:         zw.buf,
73		treeSels:    zw.treeSels,
74		treeSelsMTF: zw.treeSelsMTF,
75		trees1D:     zw.trees1D,
76	}
77	zw.wr.Init(w)
78	if len(zw.buf) != zw.level*blockSize {
79		zw.buf = make([]byte, zw.level*blockSize)
80	}
81	zw.rle.Init(zw.buf)
82	return nil
83}
84
85func (zw *Writer) Write(buf []byte) (int, error) {
86	if zw.err != nil {
87		return 0, zw.err
88	}
89
90	cnt := len(buf)
91	for {
92		wrCnt, err := zw.rle.Write(buf)
93		if err != rleDone && zw.err == nil {
94			zw.err = err
95		}
96		zw.crc.update(buf[:wrCnt])
97		buf = buf[wrCnt:]
98		if len(buf) == 0 {
99			zw.InputOffset += int64(cnt)
100			return cnt, nil
101		}
102		if zw.err = zw.flush(); zw.err != nil {
103			return 0, zw.err
104		}
105	}
106}
107
108func (zw *Writer) flush() error {
109	vals := zw.rle.Bytes()
110	if len(vals) == 0 {
111		return nil
112	}
113	zw.wr.Offset = zw.OutputOffset
114	func() {
115		defer errors.Recover(&zw.err)
116		if !zw.wrHdr {
117			// Write stream header.
118			zw.wr.WriteBitsBE64(hdrMagic, 16)
119			zw.wr.WriteBitsBE64('h', 8)
120			zw.wr.WriteBitsBE64(uint64('0'+zw.level), 8)
121			zw.wrHdr = true
122		}
123		zw.encodeBlock(vals)
124	}()
125	var err error
126	if zw.OutputOffset, err = zw.wr.Flush(); zw.err == nil {
127		zw.err = err
128	}
129	if zw.err != nil {
130		zw.err = errWrap(zw.err, errors.Internal)
131		return zw.err
132	}
133	zw.endCRC = (zw.endCRC<<1 | zw.endCRC>>31) ^ zw.blkCRC
134	zw.blkCRC = 0
135	zw.rle.Init(zw.buf)
136	return nil
137}
138
139func (zw *Writer) Close() error {
140	if zw.err == errClosed {
141		return nil
142	}
143
144	// Flush RLE buffer if there is left-over data.
145	if zw.err = zw.flush(); zw.err != nil {
146		return zw.err
147	}
148
149	// Write stream footer.
150	zw.wr.Offset = zw.OutputOffset
151	func() {
152		defer errors.Recover(&zw.err)
153		if !zw.wrHdr {
154			// Write stream header.
155			zw.wr.WriteBitsBE64(hdrMagic, 16)
156			zw.wr.WriteBitsBE64('h', 8)
157			zw.wr.WriteBitsBE64(uint64('0'+zw.level), 8)
158			zw.wrHdr = true
159		}
160		zw.wr.WriteBitsBE64(endMagic, 48)
161		zw.wr.WriteBitsBE64(uint64(zw.endCRC), 32)
162		zw.wr.WritePads(0)
163	}()
164	var err error
165	if zw.OutputOffset, err = zw.wr.Flush(); zw.err == nil {
166		zw.err = err
167	}
168	if zw.err != nil {
169		zw.err = errWrap(zw.err, errors.Internal)
170		return zw.err
171	}
172
173	zw.err = errClosed
174	return nil
175}
176
177func (zw *Writer) encodeBlock(buf []byte) {
178	zw.blkCRC = zw.crc.val
179	zw.wr.WriteBitsBE64(blkMagic, 48)
180	zw.wr.WriteBitsBE64(uint64(zw.blkCRC), 32)
181	zw.wr.WriteBitsBE64(0, 1)
182	zw.crc.val = 0
183
184	// Step 1: Burrows-Wheeler transformation.
185	ptr := zw.bwt.Encode(buf)
186	zw.wr.WriteBitsBE64(uint64(ptr), 24)
187
188	// Step 2: Move-to-front transform and run-length encoding.
189	var dictMap [256]bool
190	for _, c := range buf {
191		dictMap[c] = true
192	}
193
194	var dictArr [256]uint8
195	var bmapLo [16]uint16
196	dict := dictArr[:0]
197	bmapHi := uint16(0)
198	for i, b := range dictMap {
199		if b {
200			c := uint8(i)
201			dict = append(dict, c)
202			bmapHi |= 1 << (c >> 4)
203			bmapLo[c>>4] |= 1 << (c & 0xf)
204		}
205	}
206
207	zw.wr.WriteBits(uint(bmapHi), 16)
208	for _, m := range bmapLo {
209		if m > 0 {
210			zw.wr.WriteBits(uint(m), 16)
211		}
212	}
213
214	zw.mtf.Init(dict, len(buf))
215	syms := zw.mtf.Encode(buf)
216
217	// Step 3: Prefix encoding.
218	zw.encodePrefix(syms, len(dict))
219}
220
221func (zw *Writer) encodePrefix(syms []uint16, numSyms int) {
222	numSyms += 2 // Remove 0 symbol, add RUNA, RUNB, and EOB symbols
223	if numSyms < 3 {
224		panicf(errors.Internal, "unable to encode EOB marker")
225	}
226	syms = append(syms, uint16(numSyms-1)) // EOB marker
227
228	// Compute number of prefix trees needed.
229	numTrees := maxNumTrees
230	for i, lim := range []int{200, 600, 1200, 2400} {
231		if len(syms) < lim {
232			numTrees = minNumTrees + i
233			break
234		}
235	}
236
237	// Compute number of block selectors.
238	numSels := (len(syms) + numBlockSyms - 1) / numBlockSyms
239	if cap(zw.treeSels) < numSels {
240		zw.treeSels = make([]uint8, numSels)
241	}
242	treeSels := zw.treeSels[:numSels]
243	for i := range treeSels {
244		treeSels[i] = uint8(i % numTrees)
245	}
246
247	// Initialize prefix codes.
248	for i := range zw.codes2D[:numTrees] {
249		pc := zw.codes2D[i][:numSyms]
250		for j := range pc {
251			pc[j] = prefix.PrefixCode{Sym: uint32(j)}
252		}
253		zw.codes1D[i] = pc
254	}
255
256	// First cut at assigning prefix trees to each group.
257	var codes prefix.PrefixCodes
258	var blkLen, selIdx int
259	for _, sym := range syms {
260		if blkLen == 0 {
261			blkLen = numBlockSyms
262			codes = zw.codes2D[treeSels[selIdx]][:numSyms]
263			selIdx++
264		}
265		blkLen--
266		codes[sym].Cnt++
267	}
268
269	// TODO(dsnet): Use K-means to cluster groups to each prefix tree.
270
271	// Generate lengths and prefixes based on symbol frequencies.
272	for i := range zw.trees1D[:numTrees] {
273		pc := prefix.PrefixCodes(zw.codes2D[i][:numSyms])
274		pc.SortByCount()
275		if err := prefix.GenerateLengths(pc, maxPrefixBits); err != nil {
276			errors.Panic(err)
277		}
278		pc.SortBySymbol()
279	}
280
281	// Write out information about the trees and tree selectors.
282	var mtf internal.MoveToFront
283	zw.wr.WriteBitsBE64(uint64(numTrees), 3)
284	zw.wr.WriteBitsBE64(uint64(numSels), 15)
285	zw.treeSelsMTF = append(zw.treeSelsMTF[:0], treeSels...)
286	mtf.Encode(zw.treeSelsMTF)
287	for _, sym := range zw.treeSelsMTF {
288		zw.wr.WriteSymbol(uint(sym), &encSel)
289	}
290	zw.wr.WritePrefixCodes(zw.codes1D[:numTrees], zw.trees1D[:numTrees])
291
292	// Write out prefix encoded symbols of compressed data.
293	var tree *prefix.Encoder
294	blkLen, selIdx = 0, 0
295	for _, sym := range syms {
296		if blkLen == 0 {
297			blkLen = numBlockSyms
298			tree = &zw.trees1D[treeSels[selIdx]]
299			selIdx++
300		}
301		blkLen--
302		ok := zw.wr.TryWriteSymbol(uint(sym), tree)
303		if !ok {
304			zw.wr.WriteSymbol(uint(sym), tree)
305		}
306	}
307}
308