1// Copyright (c) 2010, Andrei Vieru. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package lzma
6
import (
	"bufio"
	"io"
	"sync"
)
11
// Range-coder constants shared by the encoder and the decoder.
const (
	// kTopValue is the renormalization threshold: once the range falls
	// below it, the decoder shifts in another input byte and the encoder
	// shifts out another output byte.
	kTopValue = 1 << 24

	// Adaptive bit probabilities are kNumBitModelTotalBits-bit fixed-point
	// numbers; kBitModelTotal stands for probability 1.0.
	kNumBitModelTotalBits = 11
	kBitModelTotal        = 1 << kNumBitModelTotalBits

	// kNumMoveBits sets how far a probability moves toward the coded bit
	// after each update (a shift of the remaining distance).
	kNumMoveBits = 5
)
18
// Reader is the read interface needed by NewDecoder: an io.Reader that can
// also return single bytes. If the io.Reader handed to NewDecoder does not
// implement ReadByte, the decoder wraps it in a bufio.Reader of its own.
type Reader interface {
	io.Reader
	io.ByteReader
}
26
27type rangeDecoder struct {
28	r      Reader
29	rrange uint32
30	code   uint32
31}
32
33func makeReader(r io.Reader) Reader {
34	if rr, ok := r.(Reader); ok {
35		return rr
36	}
37	return bufio.NewReader(r)
38}
39
40func newRangeDecoder(r io.Reader) *rangeDecoder {
41	rd := &rangeDecoder{
42		r:      makeReader(r),
43		rrange: 0xFFFFFFFF,
44		code:   0,
45	}
46	buf := make([]byte, 5)
47	n, err := rd.r.Read(buf)
48	if err != nil {
49		throw(err)
50	}
51	if n != len(buf) {
52		throw(nReadError)
53	}
54	for i := 0; i < len(buf); i++ {
55		rd.code = rd.code<<8 | uint32(buf[i])
56	}
57	return rd
58}
59
60func (rd *rangeDecoder) decodeDirectBits(numTotalBits uint32) (res uint32) {
61	for i := numTotalBits; i != 0; i-- {
62		rd.rrange >>= 1
63		t := (rd.code - rd.rrange) >> 31
64		rd.code -= rd.rrange & (t - 1)
65		res = res<<1 | (1 - t)
66		if rd.rrange < kTopValue {
67			c, err := rd.r.ReadByte()
68			if err != nil {
69				throw(err)
70			}
71			rd.code = rd.code<<8 | uint32(c)
72			rd.rrange <<= 8
73		}
74	}
75	return
76}
77
78func (rd *rangeDecoder) decodeBit(probs []uint16, index uint32) (res uint32) {
79	prob := probs[index]
80	newBound := (rd.rrange >> kNumBitModelTotalBits) * uint32(prob)
81	if rd.code < newBound {
82		rd.rrange = newBound
83		probs[index] = prob + (kBitModelTotal-prob)>>kNumMoveBits
84		if rd.rrange < kTopValue {
85			b, err := rd.r.ReadByte()
86			if err != nil {
87				throw(err)
88			}
89			rd.code = rd.code<<8 | uint32(b)
90			rd.rrange <<= 8
91		}
92		res = 0
93	} else {
94		rd.rrange -= newBound
95		rd.code -= newBound
96		probs[index] = prob - prob>>kNumMoveBits
97		if rd.rrange < kTopValue {
98			b, err := rd.r.ReadByte()
99			if err != nil {
100				throw(err)
101			}
102			rd.code = rd.code<<8 | uint32(b)
103			rd.rrange <<= 8
104		}
105		res = 1
106	}
107	return
108}
109
110func initBitModels(length uint32) (probs []uint16) {
111	probs = make([]uint16, length)
112	val := uint16(kBitModelTotal) >> 1
113	for i := uint32(0); i < length; i++ {
114		probs[i] = val // 1 << 10
115	}
116	return
117}
118
// Constants for the encoder's bit-price tables.
const (
	// kNumMoveReducingBits is how many low bits of a probability are dropped
	// when indexing probPrices.
	kNumMoveReducingBits = 2

	// kNumBitPriceShiftBits is the fixed-point scale of prices: one whole bit
	// costs 1 << kNumBitPriceShiftBits.
	kNumBitPriceShiftBits = 6
)
123
// Writer is the write interface needed by NewEncoder: an io.Writer that can
// also write single bytes and be flushed. If the io.Writer handed to
// NewEncoder lacks WriteByte or Flush, the encoder wraps it in a
// bufio.Writer.
type Writer interface {
	io.Writer
	io.ByteWriter
	Flush() error
}
133
134type rangeEncoder struct {
135	w         Writer
136	low       uint64
137	pos       uint64
138	cacheSize uint32
139	cache     uint32
140	rrange    uint32
141}
142
143func makeWriter(w io.Writer) Writer {
144	if ww, ok := w.(Writer); ok {
145		return ww
146	}
147	return bufio.NewWriter(w)
148}
149
150func newRangeEncoder(w io.Writer) *rangeEncoder {
151	return &rangeEncoder{
152		w:         makeWriter(w),
153		low:       0,
154		pos:       0,
155		cacheSize: 1,
156		cache:     0,
157		rrange:    0xFFFFFFFF,
158	}
159}
160
161func (re *rangeEncoder) flush() {
162	for i := 0; i < 5; i++ {
163		re.shiftLow()
164	}
165	err := re.w.Flush()
166	if err != nil {
167		throw(err)
168	}
169}
170
171func (re *rangeEncoder) shiftLow() {
172	lowHi := uint32(re.low >> 32)
173	if lowHi != 0 || re.low < uint64(0x00000000FF000000) {
174		re.pos += uint64(re.cacheSize)
175		temp := re.cache
176		dwtemp := uint32(1) // execute the loop at least once (do-while)
177		for ; dwtemp != 0; dwtemp = re.cacheSize {
178			err := re.w.WriteByte(byte(temp + lowHi))
179			if err != nil {
180				throw(err)
181			}
182			temp = 0x000000FF
183			re.cacheSize--
184		}
185		re.cache = uint32(re.low) >> 24
186	}
187	re.cacheSize++
188	re.low = uint64(uint32(re.low) << 8)
189}
190
191func (re *rangeEncoder) encodeDirectBits(v, numTotalBits uint32) {
192	for i := numTotalBits - 1; int32(i) >= 0; i-- {
193		re.rrange >>= 1
194		if (v>>i)&1 == 1 {
195			re.low += uint64(re.rrange)
196		}
197		if re.rrange < kTopValue {
198			re.rrange <<= 8
199			re.shiftLow()
200		}
201	}
202}
203
204func (re *rangeEncoder) processedSize() uint64 {
205	return uint64(re.cacheSize) + re.pos + 4
206}
207
208func (re *rangeEncoder) encode(probs []uint16, index, symbol uint32) {
209	prob := probs[index]
210	newBound := (re.rrange >> kNumBitModelTotalBits) * uint32(prob)
211	if symbol == 0 {
212		re.rrange = newBound
213		probs[index] = prob + (kBitModelTotal-prob)>>kNumMoveBits
214	} else {
215		re.low += uint64(newBound) & uint64(0xFFFFFFFFFFFFFFFF)
216		re.rrange -= newBound
217		probs[index] = prob - prob>>kNumMoveBits
218	}
219	if re.rrange < kTopValue {
220		re.rrange <<= 8
221		re.shiftLow()
222	}
223}
224
225var probPrices []uint32 = make([]uint32, kBitModelTotal>>kNumMoveReducingBits) // len(probPrices) = 512
226
227// should be called in the encoder's contructor.
228func initProbPrices() {
229	kNumBits := uint32(kNumBitModelTotalBits - kNumMoveReducingBits)
230	for i := kNumBits - 1; int32(i) >= 0; i-- {
231		start := uint32(1) << (kNumBits - i - 1)
232		end := uint32(1) << (kNumBits - i)
233		for j := start; j < end; j++ {
234			probPrices[j] = i<<kNumBitPriceShiftBits + ((end-j)<<kNumBitPriceShiftBits)>>(kNumBits-i-1)
235		}
236	}
237}
238
239func getPrice(prob uint16, symbol uint32) uint32 {
240	return probPrices[(((uint32(prob)-symbol)^(-symbol))&(uint32(kBitModelTotal)-1))>>kNumMoveReducingBits]
241}
242
243func getPrice0(prob uint16) uint32 {
244	return probPrices[prob>>kNumMoveReducingBits]
245}
246
247func getPrice1(prob uint16) uint32 {
248	return probPrices[(kBitModelTotal-prob)>>kNumMoveReducingBits]
249}
250