1package govarint
2
3import "encoding/binary"
4import "io"
5
// U32VarintEncoder is implemented by encoders that write uint32 values.
//
// PutU32 returns the number of bytes the call wrote to the underlying writer
// (an encoder that buffers may write nothing for a given call) and any write
// error. Close writes out anything the encoder still has buffered.
type U32VarintEncoder interface {
	PutU32(x uint32) (int, error)
	Close()
}

// U32VarintDecoder is implemented by decoders that read uint32 values.
// GetU32 returns the next value, or an error once no further value can be
// produced.
type U32VarintDecoder interface {
	GetU32() (uint32, error)
}

///

// U64VarintEncoder is implemented by encoders that write uint64 values.
//
// PutU64 returns the number of bytes the call wrote to the underlying writer
// and any write error. Close writes out anything the encoder still has
// buffered.
type U64VarintEncoder interface {
	PutU64(x uint64) (int, error)
	Close()
}

// U64VarintDecoder is implemented by decoders that read uint64 values.
// GetU64 returns the next value, or an error once no further value can be
// produced.
type U64VarintDecoder interface {
	GetU64() (uint64, error)
}
25
26///
27
// U32GroupVarintEncoder writes uint32 values to w using group varint encoding.
//
// Values are buffered in groups of four. Each group is written as one size
// byte followed by the big-endian bytes of the four values with leading zero
// bytes dropped (every value takes at least one byte). The size byte packs
// four 2-bit fields, first value in the high bits; a field of n means the
// value occupies n+1 bytes. Call Close (or Flush) after the last value so
// that a final group of fewer than four values is written.
type U32GroupVarintEncoder struct {
	w     io.Writer
	index int       // number of values buffered in store (0-3 between calls)
	store [4]uint32 // values of the group being built
	temp  [17]byte  // scratch: 1 size byte + up to 4*4 value bytes
}

// NewU32GroupVarintEncoder returns an encoder that writes to w.
func NewU32GroupVarintEncoder(w io.Writer) *U32GroupVarintEncoder { return &U32GroupVarintEncoder{w: w} }

// Flush encodes and writes the buffered values. It returns the number of
// bytes written to w and any write error, and is a no-op when nothing is
// buffered.
//
// A group of fewer than four values is written without the one-byte zero
// placeholders of its unused slots, so the decoder recognises it by the
// stream ending early. Flushing such a group is therefore only meaningful as
// the last write of a stream; PutU32 flushes full groups on its own.
//
// The buffer is emptied even if the write fails. This keeps a later PutU32
// from overrunning it, and stops a repeated Flush or Close from writing the
// same values twice.
func (b *U32GroupVarintEncoder) Flush() (int, error) {
	if b.index == 0 {
		return 0, nil
	}
	// Zero the unused slots so their 2-bit size fields come out as 0.
	for i := b.index; i < 4; i++ {
		b.store[i] = 0
	}
	length := 1
	b.temp[0] = 0 // size fields are OR-ed in below
	for i, x := range b.store {
		size := byte(0)
		for _, shift := range [...]uint{24, 16, 8, 0} {
			// Skip leading zero bytes, but always emit the lowest byte.
			if x>>shift != 0 || shift == 0 {
				b.temp[length] = byte(x >> shift)
				length++
				size++
			}
		}
		// Slot i's field holds size-1 (0 => 1 byte), slot 0 in bits 7-6.
		b.temp[0] |= (size - 1) << (uint(3-i) * 2)
	}
	// Each unused slot was encoded as a single trailing zero byte; drop them.
	length -= 4 - b.index
	b.index = 0
	return b.w.Write(b.temp[:length])
}

// PutU32 buffers x and writes the group out once it holds four values. It
// returns the number of bytes written to w by this call (0 unless a group was
// completed) and any write error.
func (b *U32GroupVarintEncoder) PutU32(x uint32) (int, error) {
	b.store[b.index] = x
	b.index++
	if b.index < 4 {
		return 0, nil
	}
	return b.Flush()
}

// Close writes any values still buffered as a final, possibly partial, group.
// Close cannot report a write error; callers that need it should call Flush
// first, which makes the following Close a no-op.
func (b *U32GroupVarintEncoder) Close() {
	_, _ = b.Flush() // best effort: Close has no error result
}
95
96///
97
// U32GroupVarintDecoder reads uint32 values written by U32GroupVarintEncoder.
//
// Each group is a size byte followed by the big-endian bytes of up to four
// values. The size byte packs four 2-bit fields (first value in the high
// bits); a field of n means the value occupies n+1 bytes. A final group with
// fewer than four values is recognised by the stream ending right where the
// one-byte zero placeholder of its first unused slot would have been.
type U32GroupVarintDecoder struct {
	r        io.ByteReader
	group    [4]uint32 // values decoded from the current group
	pos      int       // index in group of the next value to return
	capacity int       // number of valid values in group
	err      error     // terminal error, returned once group is drained
}

// NewU32GroupVarintDecoder returns a decoder that reads group varints from r.
func NewU32GroupVarintDecoder(r io.ByteReader) *U32GroupVarintDecoder {
	// pos == capacity makes the first GetU32 call read a group.
	return &U32GroupVarintDecoder{r: r, pos: 4, capacity: 4}
}

// getGroup reads the next size byte and the values it describes into
// b.group, resetting pos and capacity.
//
// An error reading the size byte is returned as-is. Nothing was consumed at
// that point, so a clean io.EOF at a group boundary is also passed through
// this way.
//
// If the stream ends inside the group, the values decoded before that point
// stay available and b.err records the outcome:
//   - io.EOF for a legitimate final partial group;
//   - io.ErrUnexpectedEOF for a truncated stream.
//
// Any other read error inside the group is recorded in b.err and returned.
func (b *U32GroupVarintDecoder) getGroup() error {
	sizeByte, err := b.r.ReadByte()
	if err != nil {
		return err
	}
	b.pos, b.capacity = 0, 0
	for i := range b.group {
		// Slot i's size field: bits 7-6 for slot 0 down to bits 1-0 for slot 3.
		size := int((sizeByte>>(uint(3-i)*2))&3) + 1
		var v uint32
		for j := 0; j < size; j++ {
			c, err := b.r.ReadByte()
			if err != nil {
				switch {
				case err != io.EOF:
					// The stream position is now unknown; stop for good.
					b.err, b.capacity = err, 0
					return err
				case i > 0 && sizeByte&(0xFF>>uint(2*i)) == 0:
					// The encoder drops only the zero placeholders of unused
					// slots, whose size fields are all zero: a clean end.
					b.err = io.EOF
				default:
					// EOF inside a value, before any value, or before a slot
					// the size byte says was written: the stream was cut short.
					b.err = io.ErrUnexpectedEOF
				}
				return nil
			}
			v = v<<8 | uint32(c)
		}
		b.group[i] = v
		b.capacity++
	}
	return nil
}

// GetU32 returns the next value.
//
// After the last value it returns io.EOF. It returns io.ErrUnexpectedEOF if
// the stream was truncated, and passes other read errors through. Once a
// group has ended in an error, that error is returned on every later call.
func (b *U32GroupVarintDecoder) GetU32() (uint32, error) {
	if b.pos == b.capacity {
		if b.err != nil {
			return 0, b.err
		}
		if err := b.getGroup(); err != nil {
			return 0, err
		}
		// The group may have ended before yielding any value.
		if b.pos == b.capacity {
			return 0, b.err
		}
	}
	v := b.group[b.pos]
	b.pos++
	return v, nil
}
185
186///
187
// Base128Encoder writes unsigned integers to w as base-128 varints, in the
// format produced by encoding/binary.PutUvarint. Nothing is buffered between
// calls: each Put writes its value immediately.
type Base128Encoder struct {
	w        io.Writer
	tmpBytes [binary.MaxVarintLen64]byte // scratch space that fits any uint64 varint
}

// NewU32Base128Encoder returns a Base128Encoder that writes to w.
// The encoder also accepts full 64-bit values through PutU64.
func NewU32Base128Encoder(w io.Writer) *Base128Encoder {
	return &Base128Encoder{w: w}
}

// NewU64Base128Encoder returns a Base128Encoder that writes to w.
func NewU64Base128Encoder(w io.Writer) *Base128Encoder {
	return &Base128Encoder{w: w}
}

// PutU32 writes x as a varint. It returns the number of bytes written and
// any write error.
func (b *Base128Encoder) PutU32(x uint32) (int, error) {
	return b.PutU64(uint64(x))
}

// PutU64 writes x as a varint. It returns the number of bytes written and
// any write error.
func (b *Base128Encoder) PutU64(x uint64) (int, error) {
	n := binary.PutUvarint(b.tmpBytes[:], x)
	return b.w.Write(b.tmpBytes[:n])
}

// Close does nothing: every value is written by the Put call that received
// it, so there is nothing left to flush.
func (b *Base128Encoder) Close() {
}
212
213///
214
// Base128Decoder reads unsigned base-128 varints from r, using
// encoding/binary.ReadUvarint. Errors from ReadUvarint, including the
// end-of-input error, are passed through unchanged.
type Base128Decoder struct {
	r io.ByteReader
}

// NewU32Base128Decoder returns a Base128Decoder that reads from r.
func NewU32Base128Decoder(r io.ByteReader) *Base128Decoder { return &Base128Decoder{r: r} }

// NewU64Base128Decoder returns a Base128Decoder that reads from r.
func NewU64Base128Decoder(r io.ByteReader) *Base128Decoder { return &Base128Decoder{r: r} }

// base128Error is an error whose value is its message, so decoding errors
// can be declared as constants.
type base128Error string

func (e base128Error) Error() string { return string(e) }

// errU32Overflow is returned by GetU32 for a varint that does not fit in 32 bits.
const errU32Overflow = base128Error("govarint: varint overflows a 32-bit integer")

// GetU32 reads the next varint as a uint32.
//
// A value wider than 32 bits is reported as an error rather than silently
// truncated. The reader is still advanced past that value.
func (b *Base128Decoder) GetU32() (uint32, error) {
	v, err := binary.ReadUvarint(b.r)
	if err != nil {
		return 0, err
	}
	if v>>32 != 0 {
		return 0, errU32Overflow
	}
	return uint32(v), nil
}

// GetU64 reads the next varint as a uint64.
func (b *Base128Decoder) GetU64() (uint64, error) {
	return binary.ReadUvarint(b.r)
}
230