1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package bzip2
6
7import "github.com/dsnet/compress/internal/errors"
8
9// moveToFront implements both the MTF and RLE stages of bzip2 at the same time.
// Any runs of zeros in the encoded output will be replaced by a sequence of
// RUNA and RUNB symbols that encode the length of the run.
12//
13// The RLE encoding used can actually be encoded to and decoded from using
14// normal two's complement arithmetic. The methodology for doing so is below.
15//
16// Assuming the following:
17//	num: The value being encoded by RLE encoding.
18//	run: A sequence of RUNA and RUNB symbols represented as a binary integer,
19//	where RUNA is the 0 bit, RUNB is the 1 bit, and least-significant RUN
20//	symbols are at the least-significant bit positions.
21//	cnt: The number of RUNA and RUNB symbols.
22//
23// Then the RLE encoding used by bzip2 has this mathematical property:
24//	num+1 == (1<<cnt) | run
type moveToFront struct {
	// dictBuf is the backing storage for the alphabet. Init copies the
	// caller's dictionary into it, and Encode and Decode reorder its first
	// dictLen entries in place as values are moved to the front.
	dictBuf [256]uint8
	// dictLen is the number of entries of dictBuf that are in use.
	dictLen int

	// vals is the output buffer that Decode truncates, appends to, stores
	// back, and returns; it is reused across calls.
	vals []byte
	// syms is the output buffer that Encode truncates, appends to, stores
	// back, and returns; it is reused across calls.
	syms []uint16
	// blkSize is the limit set by Init on the number of values that Encode
	// accepts and that Decode may produce.
	blkSize int
}
33
34func (mtf *moveToFront) Init(dict []uint8, blkSize int) {
35	if len(dict) > len(mtf.dictBuf) {
36		panicf(errors.Internal, "alphabet too large")
37	}
38	copy(mtf.dictBuf[:], dict)
39	mtf.dictLen = len(dict)
40	mtf.blkSize = blkSize
41}
42
43func (mtf *moveToFront) Encode(vals []byte) (syms []uint16) {
44	dict := mtf.dictBuf[:mtf.dictLen]
45	syms = mtf.syms[:0]
46
47	if len(vals) > mtf.blkSize {
48		panicf(errors.Internal, "exceeded block size")
49	}
50
51	var lastNum uint32
52	for _, val := range vals {
53		// Normal move-to-front transform.
54		var idx uint8 // Reverse lookup idx in dict
55		for di, dv := range dict {
56			if dv == val {
57				idx = uint8(di)
58				break
59			}
60		}
61		copy(dict[1:], dict[:idx])
62		dict[0] = val
63
64		// Run-length encoding augmentation.
65		if idx == 0 {
66			lastNum++
67			continue
68		}
69		if lastNum > 0 {
70			for rc := lastNum + 1; rc != 1; rc >>= 1 {
71				syms = append(syms, uint16(rc&1))
72			}
73			lastNum = 0
74		}
75		syms = append(syms, uint16(idx)+1)
76	}
77	if lastNum > 0 {
78		for rc := lastNum + 1; rc != 1; rc >>= 1 {
79			syms = append(syms, uint16(rc&1))
80		}
81	}
82	mtf.syms = syms
83	return syms
84}
85
86func (mtf *moveToFront) Decode(syms []uint16) (vals []byte) {
87	dict := mtf.dictBuf[:mtf.dictLen]
88	vals = mtf.vals[:0]
89
90	var lastCnt uint
91	var lastRun uint32
92	for _, sym := range syms {
93		// Run-length encoding augmentation.
94		if sym < 2 {
95			lastRun |= uint32(sym) << lastCnt
96			lastCnt++
97			continue
98		}
99		if lastCnt > 0 {
100			cnt := int((1<<lastCnt)|lastRun) - 1
101			if len(vals)+cnt > mtf.blkSize || lastCnt > 24 {
102				panicf(errors.Corrupted, "run-length decoding exceeded block size")
103			}
104			for i := cnt; i > 0; i-- {
105				vals = append(vals, dict[0])
106			}
107			lastCnt, lastRun = 0, 0
108		}
109
110		// Normal move-to-front transform.
111		val := dict[sym-1] // Forward lookup val in dict
112		copy(dict[1:], dict[:sym-1])
113		dict[0] = val
114
115		if len(vals) >= mtf.blkSize {
116			panicf(errors.Corrupted, "run-length decoding exceeded block size")
117		}
118		vals = append(vals, val)
119	}
120	if lastCnt > 0 {
121		cnt := int((1<<lastCnt)|lastRun) - 1
122		if len(vals)+cnt > mtf.blkSize || lastCnt > 24 {
123			panicf(errors.Corrupted, "run-length decoding exceeded block size")
124		}
125		for i := cnt; i > 0; i-- {
126			vals = append(vals, dict[0])
127		}
128	}
129	mtf.vals = vals
130	return vals
131}
132