1// Copyright 2014-2019 Ulrich Kunitz. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package lzma
6
7import (
8	"errors"
9	"io"
10)
11
12// rangeEncoder implements range encoding of single bits. The low value can
13// overflow therefore we need uint64. The cache value is used to handle
14// overflows.
15type rangeEncoder struct {
16	lbw      *LimitedByteWriter
17	nrange   uint32
18	low      uint64
19	cacheLen int64
20	cache    byte
21}
22
// maxInt64 is the largest value an int64 can hold (2^63 - 1).
const maxInt64 = 0x7fffffffffffffff
25
26// newRangeEncoder creates a new range encoder.
27func newRangeEncoder(bw io.ByteWriter) (re *rangeEncoder, err error) {
28	lbw, ok := bw.(*LimitedByteWriter)
29	if !ok {
30		lbw = &LimitedByteWriter{BW: bw, N: maxInt64}
31	}
32	return &rangeEncoder{
33		lbw:      lbw,
34		nrange:   0xffffffff,
35		cacheLen: 1}, nil
36}
37
38// Available returns the number of bytes that still can be written. The
39// method takes the bytes that will be currently written by Close into
40// account.
41func (e *rangeEncoder) Available() int64 {
42	return e.lbw.N - (e.cacheLen + 4)
43}
44
45// writeByte writes a single byte to the underlying writer. An error is
46// returned if the limit is reached. The written byte will be counted if
47// the underlying writer doesn't return an error.
48func (e *rangeEncoder) writeByte(c byte) error {
49	if e.Available() < 1 {
50		return ErrLimit
51	}
52	return e.lbw.WriteByte(c)
53}
54
55// DirectEncodeBit encodes the least-significant bit of b with probability 1/2.
56func (e *rangeEncoder) DirectEncodeBit(b uint32) error {
57	e.nrange >>= 1
58	e.low += uint64(e.nrange) & (0 - (uint64(b) & 1))
59
60	// normalize
61	const top = 1 << 24
62	if e.nrange >= top {
63		return nil
64	}
65	e.nrange <<= 8
66	return e.shiftLow()
67}
68
69// EncodeBit encodes the least significant bit of b. The p value will be
70// updated by the function depending on the bit encoded.
71func (e *rangeEncoder) EncodeBit(b uint32, p *prob) error {
72	bound := p.bound(e.nrange)
73	if b&1 == 0 {
74		e.nrange = bound
75		p.inc()
76	} else {
77		e.low += uint64(bound)
78		e.nrange -= bound
79		p.dec()
80	}
81
82	// normalize
83	const top = 1 << 24
84	if e.nrange >= top {
85		return nil
86	}
87	e.nrange <<= 8
88	return e.shiftLow()
89}
90
91// Close writes a complete copy of the low value.
92func (e *rangeEncoder) Close() error {
93	for i := 0; i < 5; i++ {
94		if err := e.shiftLow(); err != nil {
95			return err
96		}
97	}
98	return nil
99}
100
101// shiftLow shifts the low value for 8 bit. The shifted byte is written into
102// the byte writer. The cache value is used to handle overflows.
103func (e *rangeEncoder) shiftLow() error {
104	if uint32(e.low) < 0xff000000 || (e.low>>32) != 0 {
105		tmp := e.cache
106		for {
107			err := e.writeByte(tmp + byte(e.low>>32))
108			if err != nil {
109				return err
110			}
111			tmp = 0xff
112			e.cacheLen--
113			if e.cacheLen <= 0 {
114				if e.cacheLen < 0 {
115					panic("negative cacheLen")
116				}
117				break
118			}
119		}
120		e.cache = byte(uint32(e.low) >> 24)
121	}
122	e.cacheLen++
123	e.low = uint64(uint32(e.low) << 8)
124	return nil
125}
126
// rangeDecoder holds the state needed to decode single bits from a
// range-encoded byte stream.
type rangeDecoder struct {
	// br supplies the encoded bytes.
	br io.ByteReader

	// nrange is the width of the current coding interval.
	nrange uint32

	// code holds the most recently read stream bytes relative to the
	// lower end of the interval; the decoding functions assume that it
	// stays below nrange.
	code uint32
}
133
134// init initializes the range decoder, by reading from the byte reader.
135func (d *rangeDecoder) init() error {
136	d.nrange = 0xffffffff
137	d.code = 0
138
139	b, err := d.br.ReadByte()
140	if err != nil {
141		return err
142	}
143	if b != 0 {
144		return errors.New("newRangeDecoder: first byte not zero")
145	}
146
147	for i := 0; i < 4; i++ {
148		if err = d.updateCode(); err != nil {
149			return err
150		}
151	}
152
153	if d.code >= d.nrange {
154		return errors.New("newRangeDecoder: d.code >= d.nrange")
155	}
156
157	return nil
158}
159
160// newRangeDecoder initializes a range decoder. It reads five bytes from the
161// reader and therefore may return an error.
162func newRangeDecoder(br io.ByteReader) (d *rangeDecoder, err error) {
163	d = &rangeDecoder{br: br, nrange: 0xffffffff}
164
165	b, err := d.br.ReadByte()
166	if err != nil {
167		return nil, err
168	}
169	if b != 0 {
170		return nil, errors.New("newRangeDecoder: first byte not zero")
171	}
172
173	for i := 0; i < 4; i++ {
174		if err = d.updateCode(); err != nil {
175			return nil, err
176		}
177	}
178
179	if d.code >= d.nrange {
180		return nil, errors.New("newRangeDecoder: d.code >= d.nrange")
181	}
182
183	return d, nil
184}
185
186// possiblyAtEnd checks whether the decoder may be at the end of the stream.
187func (d *rangeDecoder) possiblyAtEnd() bool {
188	return d.code == 0
189}
190
191// DirectDecodeBit decodes a bit with probability 1/2. The return value b will
192// contain the bit at the least-significant position. All other bits will be
193// zero.
194func (d *rangeDecoder) DirectDecodeBit() (b uint32, err error) {
195	d.nrange >>= 1
196	d.code -= d.nrange
197	t := 0 - (d.code >> 31)
198	d.code += d.nrange & t
199	b = (t + 1) & 1
200
201	// d.code will stay less then d.nrange
202
203	// normalize
204	// assume d.code < d.nrange
205	const top = 1 << 24
206	if d.nrange >= top {
207		return b, nil
208	}
209	d.nrange <<= 8
210	// d.code < d.nrange will be maintained
211	return b, d.updateCode()
212}
213
214// decodeBit decodes a single bit. The bit will be returned at the
215// least-significant position. All other bits will be zero. The probability
216// value will be updated.
217func (d *rangeDecoder) DecodeBit(p *prob) (b uint32, err error) {
218	bound := p.bound(d.nrange)
219	if d.code < bound {
220		d.nrange = bound
221		p.inc()
222		b = 0
223	} else {
224		d.code -= bound
225		d.nrange -= bound
226		p.dec()
227		b = 1
228	}
229	// normalize
230	// assume d.code < d.nrange
231	const top = 1 << 24
232	if d.nrange >= top {
233		return b, nil
234	}
235	d.nrange <<= 8
236	// d.code < d.nrange will be maintained
237	return b, d.updateCode()
238}
239
240// updateCode reads a new byte into the code.
241func (d *rangeDecoder) updateCode() error {
242	b, err := d.br.ReadByte()
243	if err != nil {
244		return err
245	}
246	d.code = (d.code << 8) | uint32(b)
247	return nil
248}
249