1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17// Package utils contains various internal utilities for the parquet library
18// that aren't intended to be exposed to external consumers such as interfaces
19// and bitmap readers/writers including the RLE encoder/decoder and so on.
20package utils
21
22import (
23	"bytes"
24	"encoding/binary"
25	"io"
26	"math"
27
28	"github.com/apache/arrow/go/v6/arrow/bitutil"
29	"github.com/apache/arrow/go/v6/parquet"
30	"golang.org/x/xerrors"
31)
32
33//go:generate go run ../../../arrow/_tools/tmpl/main.go -i -data=physical_types.tmpldata typed_rle_dict.gen.go.tmpl
34
const (
	// MaxValuesPerLiteralRun is the largest number of values a single literal
	// run can hold: the reserved 1-byte indicator stores the number of 8-value
	// groups shifted left by one, leaving 6 usable bits, so (1 << 6) groups of 8.
	MaxValuesPerLiteralRun = (1 << 6) * 8
)
38
39func MinBufferSize(bitWidth int) int {
40	maxLiteralRunSize := 1 + bitutil.BytesForBits(int64(MaxValuesPerLiteralRun*bitWidth))
41	maxRepeatedRunSize := binary.MaxVarintLen32 + bitutil.BytesForBits(int64(bitWidth))
42	return int(Max(maxLiteralRunSize, maxRepeatedRunSize))
43}
44
45func MaxBufferSize(width, numValues int) int {
46	bytesPerRun := width
47	numRuns := int(bitutil.BytesForBits(int64(numValues)))
48	literalMaxSize := numRuns + (numRuns * bytesPerRun)
49
50	minRepeatedRunSize := 1 + int(bitutil.BytesForBits(int64(width)))
51	repeatedMaxSize := int(bitutil.BytesForBits(int64(numValues))) * minRepeatedRunSize
52
53	return MaxInt(literalMaxSize, repeatedMaxSize)
54}
55
56// Utility classes to do run length encoding (RLE) for fixed bit width values.  If runs
57// are sufficiently long, RLE is used, otherwise, the values are just bit-packed
58// (literal encoding).
59// For both types of runs, there is a byte-aligned indicator which encodes the length
60// of the run and the type of the run.
61// This encoding has the benefit that when there aren't any long enough runs, values
62// are always decoded at fixed (can be precomputed) bit offsets OR both the value and
63// the run length are byte aligned. This allows for very efficient decoding
64// implementations.
65// The encoding is:
66//    encoded-block := run*
67//    run := literal-run | repeated-run
68//    literal-run := literal-indicator < literal bytes >
69//    repeated-run := repeated-indicator < repeated value. padded to byte boundary >
70//    literal-indicator := varint_encode( number_of_groups << 1 | 1)
71//    repeated-indicator := varint_encode( number_of_repetitions << 1 )
72//
73// Each run is preceded by a varint. The varint's least significant bit is
74// used to indicate whether the run is a literal run or a repeated run. The rest
75// of the varint is used to determine the length of the run (eg how many times the
76// value repeats).
77//
78// In the case of literal runs, the run length is always a multiple of 8 (i.e. encode
79// in groups of 8), so that no matter the bit-width of the value, the sequence will end
80// on a byte boundary without padding.
81// Given that we know it is a multiple of 8, we store the number of 8-groups rather than
82// the actual number of encoded ints. (This means that the total number of encoded values
83// can not be determined from the encoded data, since the number of values in the last
84// group may not be a multiple of 8). For the last group of literal runs, we pad
85// the group to 8 with zeros. This allows for 8 at a time decoding on the read side
86// without the need for additional checks.
87//
88// There is a break-even point when it is more storage efficient to do run length
89// encoding.  For 1 bit-width values, that point is 8 values.  They require 2 bytes
90// for both the repeated encoding or the literal encoding.  This value can always
91// be computed based on the bit-width.
92//
93// Examples with bit-width 1 (eg encoding booleans):
94// ----------------------------------------
95// 100 1s followed by 100 0s:
96// <varint(100 << 1)> <1, padded to 1 byte> <varint(100 << 1)> <0, padded to 1 byte>
97//  - (total 4 bytes)
98//
99// alternating 1s and 0s (200 total):
100// 200 ints = 25 groups of 8
101// <varint((25 << 1) | 1)> <25 bytes of values, bitpacked>
102// (total 26 bytes, 1 byte overhead)
103//
104
// RleDecoder decodes a stream written in the parquet RLE/bit-packing hybrid
// encoding described above, reading bits through a BitReader.
type RleDecoder struct {
	r *BitReader

	// bit width of each encoded value
	bitWidth int
	// value being repeated by the currently active repeated run
	curVal   uint64
	// repetitions remaining in the currently active repeated run
	repCount int32
	// values remaining in the currently active literal run
	litCount int32
}
113
114func NewRleDecoder(data *bytes.Reader, width int) *RleDecoder {
115	return &RleDecoder{r: NewBitReader(data), bitWidth: width}
116}
117
118func (r *RleDecoder) Reset(data *bytes.Reader, width int) {
119	r.bitWidth = width
120	r.curVal = 0
121	r.repCount = 0
122	r.litCount = 0
123	r.r.Reset(data)
124}
125
// Next reads the next run header from the stream and sets up the decoder's
// run state (litCount or repCount/curVal). It returns false on EOF or when
// the header describes an invalid run (zero length or a count that would
// overflow an int32).
func (r *RleDecoder) Next() bool {
	indicator, ok := r.r.GetVlqInt()
	if !ok {
		return false
	}

	// low bit of the varint selects the run type; remaining bits are the count
	literal := (indicator & 1) != 0
	count := uint32(indicator >> 1)
	if literal {
		// literal counts are stored as the number of 8-value groups
		if count == 0 || count > uint32(math.MaxInt32/8) {
			return false
		}
		r.litCount = int32(count) * 8
	} else {
		if count == 0 || count > uint32(math.MaxInt32) {
			return false
		}
		r.repCount = int32(count)

		// the repeated value is stored byte-aligned in the minimum number of
		// whole bytes that can hold bitWidth bits; read into the smallest
		// unsigned type that fits
		nbytes := int(bitutil.BytesForBits(int64(r.bitWidth)))
		switch {
		case nbytes > 4:
			if !r.r.GetAligned(nbytes, &r.curVal) {
				return false
			}
		case nbytes > 2:
			var val uint32
			if !r.r.GetAligned(nbytes, &val) {
				return false
			}
			r.curVal = uint64(val)
		case nbytes > 1:
			var val uint16
			if !r.r.GetAligned(nbytes, &val) {
				return false
			}
			r.curVal = uint64(val)
		default:
			var val uint8
			if !r.r.GetAligned(nbytes, &val) {
				return false
			}
			r.curVal = uint64(val)
		}
	}
	return true
}
173
174func (r *RleDecoder) GetValue() (uint64, bool) {
175	vals := make([]uint64, 1)
176	n := r.GetBatch(vals)
177	return vals[0], n == 1
178}
179
180func (r *RleDecoder) GetBatch(values []uint64) int {
181	read := 0
182	size := len(values)
183
184	out := values
185	for read < size {
186		remain := size - read
187
188		if r.repCount > 0 {
189			repbatch := int(math.Min(float64(remain), float64(r.repCount)))
190			for i := 0; i < repbatch; i++ {
191				out[i] = r.curVal
192			}
193
194			r.repCount -= int32(repbatch)
195			read += repbatch
196			out = out[repbatch:]
197		} else if r.litCount > 0 {
198			litbatch := int(math.Min(float64(remain), float64(r.litCount)))
199			n, _ := r.r.GetBatch(uint(r.bitWidth), out[:litbatch])
200			if n != litbatch {
201				return read
202			}
203
204			r.litCount -= int32(litbatch)
205			read += litbatch
206			out = out[litbatch:]
207		} else {
208			if !r.Next() {
209				return read
210			}
211		}
212	}
213	return read
214}
215
// GetBatchSpaced decodes len(vals) value slots where validBits (starting at
// validBitsOffset) marks which slots hold real values; null slots are
// zero-filled instead of consuming decoded values. Returns the number of
// slots processed.
func (r *RleDecoder) GetBatchSpaced(vals []uint64, nullcount int, validBits []byte, validBitsOffset int64) (int, error) {
	if nullcount == 0 {
		// no nulls at all: plain dense decode
		return r.GetBatch(vals), nil
	}

	converter := plainConverter{}
	blockCounter := NewBitBlockCounter(validBits, validBitsOffset, int64(len(vals)))

	var (
		totalProcessed int
		processed      int
		block          BitBlockCount
		err            error
	)

	for {
		block = blockCounter.NextFourWords()
		if block.Len == 0 {
			break
		}

		if block.AllSet() {
			// block is entirely valid: decode straight through
			processed = r.GetBatch(vals[:block.Len])
		} else if block.NoneSet() {
			// block is entirely null: zero-fill, nothing decoded
			converter.FillZero(vals[:block.Len])
			processed = int(block.Len)
		} else {
			// mixed block: interleave decoded values and zero fills
			processed, err = r.getspaced(converter, vals, int(block.Len), int(block.Len-block.Popcnt), validBits, validBitsOffset)
			if err != nil {
				return totalProcessed, err
			}
		}

		totalProcessed += processed
		vals = vals[int(block.Len):]
		validBitsOffset += int64(block.Len)

		// a short block means the underlying stream ran out; stop early
		if processed != int(block.Len) {
			break
		}
	}
	return totalProcessed, nil
}
259
260func (r *RleDecoder) getspaced(dc DictionaryConverter, vals interface{}, batchSize, nullCount int, validBits []byte, validBitsOffset int64) (int, error) {
261	switch vals := vals.(type) {
262	case []int32:
263		return r.getspacedInt32(dc, vals, batchSize, nullCount, validBits, validBitsOffset)
264	case []int64:
265		return r.getspacedInt64(dc, vals, batchSize, nullCount, validBits, validBitsOffset)
266	case []float32:
267		return r.getspacedFloat32(dc, vals, batchSize, nullCount, validBits, validBitsOffset)
268	case []float64:
269		return r.getspacedFloat64(dc, vals, batchSize, nullCount, validBits, validBitsOffset)
270	case []parquet.ByteArray:
271		return r.getspacedByteArray(dc, vals, batchSize, nullCount, validBits, validBitsOffset)
272	case []parquet.FixedLenByteArray:
273		return r.getspacedFixedLenByteArray(dc, vals, batchSize, nullCount, validBits, validBitsOffset)
274	case []parquet.Int96:
275		return r.getspacedInt96(dc, vals, batchSize, nullCount, validBits, validBitsOffset)
276	case []uint64:
277		return r.getspacedUint64(dc, vals, batchSize, nullCount, validBits, validBitsOffset)
278	default:
279		return 0, xerrors.New("parquet/rle: getspaced invalid type")
280	}
281}
282
// getspacedUint64 decodes batchSize slots into vals, zero-filling the
// nullCount slots whose validity bit is clear and running decoded indices
// through the dictionary converter. Returns the number of slots processed.
func (r *RleDecoder) getspacedUint64(dc DictionaryConverter, vals []uint64, batchSize, nullCount int, validBits []byte, validBitsOffset int64) (int, error) {
	if nullCount == batchSize {
		// all slots are null: nothing needs decoding
		dc.FillZero(vals[:batchSize])
		return batchSize, nil
	}

	read := 0
	// number of valid (non-null) values still to decode
	remain := batchSize - nullCount

	// scratch buffer for decoded dictionary indices
	const bufferSize = 1024
	var indexbuffer [bufferSize]IndexType

	// assume no bits to start
	bitReader := NewBitRunReader(validBits, validBitsOffset, int64(batchSize))
	validRun := bitReader.NextRun()
	for read < batchSize {
		if validRun.Len == 0 {
			validRun = bitReader.NextRun()
		}

		if !validRun.Set {
			// a run of null slots: zero-fill without consuming decoded values
			dc.FillZero(vals[:int(validRun.Len)])
			vals = vals[int(validRun.Len):]
			read += int(validRun.Len)
			validRun.Len = 0
			continue
		}

		// no active run state: pull the next run header
		if r.repCount == 0 && r.litCount == 0 {
			if !r.Next() {
				return read, nil
			}
		}

		var batch int
		switch {
		case r.repCount > 0:
			// repeated run: nulls and non-nulls are both filled with the
			// same value (the null slots are simply ignored by readers)
			batch, remain, validRun = r.consumeRepeatCounts(read, batchSize, remain, validRun, bitReader)
			current := IndexType(r.curVal)
			if !dc.IsValid(current) {
				return read, nil
			}
			dc.Fill(vals[:batch], current)
		case r.litCount > 0:
			var (
				litread int
				skipped int
				err     error
			)
			// literal run: decode indices and interleave with zero fills
			litread, skipped, validRun, err = r.consumeLiteralsUint64(dc, vals, remain, indexbuffer[:], validRun, bitReader)
			if err != nil {
				return read, err
			}
			batch = litread + skipped
			remain -= litread
		}

		vals = vals[batch:]
		read += batch
	}
	return read, nil
}
345
// consumeRepeatCounts consumes from the current repeated run, covering both
// valid and null slots, and returns the number of slots covered, the updated
// remaining-valid count, and the updated validity run.
func (r *RleDecoder) consumeRepeatCounts(read, batchSize, remain int, run BitRun, bitRdr BitRunReader) (int, int, BitRun) {
	// Consume the entire repeat counts incrementing repeat_batch to
	// be the total of nulls + values consumed, we only need to
	// get the total count because we can fill in the same value for
	// nulls and non-nulls. This proves to be a big efficiency win.
	repeatBatch := 0
	for r.repCount > 0 && (read+repeatBatch) < batchSize {
		if run.Set {
			// valid slots consume from the repeat count
			updateSize := int(Min(run.Len, int64(r.repCount)))
			r.repCount -= int32(updateSize)
			repeatBatch += updateSize
			run.Len -= int64(updateSize)
			remain -= updateSize
		} else {
			// null slots are covered without consuming the repeat count
			repeatBatch += int(run.Len)
			run.Len = 0
		}

		if run.Len == 0 {
			run = bitRdr.NextRun()
		}
	}
	return repeatBatch, remain, run
}
370
// consumeLiteralsUint64 decodes up to one scratch-buffer's worth of literal
// indices, copies converted values into valid slots and zero-fills null
// slots. It returns the number of valid values written, the number of null
// slots skipped, the updated validity run, and any error.
func (r *RleDecoder) consumeLiteralsUint64(dc DictionaryConverter, vals []uint64, remain int, buf []IndexType, run BitRun, bitRdr BitRunReader) (int, int, BitRun, error) {
	// cap the batch by remaining valid values, the literal run length,
	// and the scratch buffer size
	batch := MinInt(MinInt(remain, int(r.litCount)), len(buf))
	buf = buf[:batch]

	n, _ := r.r.GetBatchIndex(uint(r.bitWidth), buf)
	if n != batch {
		return 0, 0, run, xerrors.New("was not able to retrieve correct number of indexes")
	}

	// validate all indices up front so the copy loop below needs no checks
	if !dc.IsValid(buf...) {
		return 0, 0, run, xerrors.New("invalid index values found for dictionary converter")
	}

	var (
		read    int
		skipped int
	)
	for read < batch {
		if run.Set {
			// copy converted values into the run of valid slots
			updateSize := MinInt(batch-read, int(run.Len))
			if err := dc.Copy(vals, buf[read:read+updateSize]); err != nil {
				return 0, 0, run, err
			}
			read += updateSize
			vals = vals[updateSize:]
			run.Len -= int64(updateSize)
		} else {
			// zero-fill null slots; they don't consume decoded indices
			dc.FillZero(vals[:int(run.Len)])
			vals = vals[int(run.Len):]
			skipped += int(run.Len)
			run.Len = 0
		}
		if run.Len == 0 {
			run = bitRdr.NextRun()
		}
	}
	r.litCount -= int32(batch)
	return read, skipped, run, nil
}
410
411func (r *RleDecoder) GetBatchWithDict(dc DictionaryConverter, vals interface{}) (int, error) {
412	switch vals := vals.(type) {
413	case []int32:
414		return r.GetBatchWithDictInt32(dc, vals)
415	case []int64:
416		return r.GetBatchWithDictInt64(dc, vals)
417	case []float32:
418		return r.GetBatchWithDictFloat32(dc, vals)
419	case []float64:
420		return r.GetBatchWithDictFloat64(dc, vals)
421	case []parquet.ByteArray:
422		return r.GetBatchWithDictByteArray(dc, vals)
423	case []parquet.FixedLenByteArray:
424		return r.GetBatchWithDictFixedLenByteArray(dc, vals)
425	case []parquet.Int96:
426		return r.GetBatchWithDictInt96(dc, vals)
427	default:
428		return 0, xerrors.New("parquet/rle: GetBatchWithDict invalid type")
429	}
430}
431
432func (r *RleDecoder) GetBatchWithDictSpaced(dc DictionaryConverter, vals interface{}, nullCount int, validBits []byte, validBitsOffset int64) (int, error) {
433	switch vals := vals.(type) {
434	case []int32:
435		return r.GetBatchWithDictSpacedInt32(dc, vals, nullCount, validBits, validBitsOffset)
436	case []int64:
437		return r.GetBatchWithDictSpacedInt64(dc, vals, nullCount, validBits, validBitsOffset)
438	case []float32:
439		return r.GetBatchWithDictSpacedFloat32(dc, vals, nullCount, validBits, validBitsOffset)
440	case []float64:
441		return r.GetBatchWithDictSpacedFloat64(dc, vals, nullCount, validBits, validBitsOffset)
442	case []parquet.ByteArray:
443		return r.GetBatchWithDictSpacedByteArray(dc, vals, nullCount, validBits, validBitsOffset)
444	case []parquet.FixedLenByteArray:
445		return r.GetBatchWithDictSpacedFixedLenByteArray(dc, vals, nullCount, validBits, validBitsOffset)
446	case []parquet.Int96:
447		return r.GetBatchWithDictSpacedInt96(dc, vals, nullCount, validBits, validBitsOffset)
448	default:
449		return 0, xerrors.New("parquet/rle: GetBatchWithDictSpaced invalid type")
450	}
451}
452
// RleEncoder produces the parquet RLE/bit-packing hybrid encoding described
// above, writing through a BitWriter.
type RleEncoder struct {
	w *BitWriter

	// buffer holds up to 8 pending values not yet committed to a run.
	// BitWidth is the bit width used for each encoded value.
	// curVal is the most recently seen value (candidate for a repeated run)
	// and repCount is how many times it has repeated consecutively.
	// litCount tracks the number of values in the current literal run.
	// literalIndicatorOffset is the byte offset of the reserved literal
	// indicator byte, or -1 when none is reserved.
	buffer                 []uint64
	BitWidth               int
	curVal                 uint64
	repCount               int32
	litCount               int32
	literalIndicatorOffset int

	// scratch byte used to rewrite the literal indicator in place
	indicatorBuffer [1]byte
}
465
466func NewRleEncoder(w io.WriterAt, width int) *RleEncoder {
467	return &RleEncoder{
468		w:                      NewBitWriter(w),
469		buffer:                 make([]uint64, 0, 8),
470		BitWidth:               width,
471		literalIndicatorOffset: -1,
472	}
473}
474
// Flush writes any pending values out as a final run, flushes the underlying
// bit writer, and returns the total number of bytes written.
func (r *RleEncoder) Flush() int {
	if r.litCount > 0 || r.repCount > 0 || len(r.buffer) > 0 {
		// a repeated run can only be emitted if there is no open literal run
		// and everything still buffered belongs to the repeat
		allRep := r.litCount == 0 && (r.repCount == int32(len(r.buffer)) || len(r.buffer) == 0)
		if r.repCount > 0 && allRep {
			r.flushRepeated()
		} else {
			// pad the last group of literals to 8 by padding with 0s
			for len(r.buffer) != 0 && len(r.buffer) < 8 {
				r.buffer = append(r.buffer, 0)
			}

			r.litCount += int32(len(r.buffer))
			r.flushLiteral(true)
			r.repCount = 0
		}
	}
	r.w.Flush(false)
	return r.w.Written()
}
494
// flushBuffered commits the buffered values, either absorbing them into an
// ongoing repeated run or emitting them as part of a literal run. done
// indicates whether this is the final flush for the current run.
func (r *RleEncoder) flushBuffered(done bool) (err error) {
	if r.repCount >= 8 {
		// clear buffered values. they are part of the repeated run now and we
		// don't want to flush them as literals
		r.buffer = r.buffer[:0]
		if r.litCount != 0 {
			// there was a current literal run. all values flushed but need to update the indicator
			err = r.flushLiteral(true)
		}
		return
	}

	r.litCount += int32(len(r.buffer))
	ngroups := r.litCount / 8
	if ngroups+1 >= (1 << 6) {
		// we need to start a new literal run because the indicator byte we've reserved
		// cannot store any more values
		err = r.flushLiteral(true)
	} else {
		err = r.flushLiteral(done)
	}
	r.repCount = 0
	return
}
519
// flushLiteral bit-packs the buffered values into the output. When
// updateIndicator is true, the reserved indicator byte is rewritten with the
// final group count and the literal run is closed out.
func (r *RleEncoder) flushLiteral(updateIndicator bool) (err error) {
	// reserve a byte for the run indicator the first time through
	if r.literalIndicatorOffset == -1 {
		r.literalIndicatorOffset = r.w.ReserveBytes(1)
	}

	for _, val := range r.buffer {
		if err = r.w.WriteValue(val, uint(r.BitWidth)); err != nil {
			return
		}
	}
	r.buffer = r.buffer[:0]

	if updateIndicator {
		// at this point we need to write the indicator byte for the literal run.
		// we only reserve one byte, to allow for streaming writes of literal values.
		// the logic makes sure we flush literal runs often enough to not overrun the 1 byte.
		ngroups := r.litCount / 8
		r.indicatorBuffer[0] = byte((ngroups << 1) | 1)
		_, err = r.w.WriteAt(r.indicatorBuffer[:], int64(r.literalIndicatorOffset))
		r.literalIndicatorOffset = -1
		r.litCount = 0
	}
	return
}
544
545func (r *RleEncoder) flushRepeated() (ret bool) {
546	indicator := r.repCount << 1
547
548	ret = r.w.WriteVlqInt(uint64(indicator))
549	ret = ret && r.w.WriteAligned(r.curVal, int(bitutil.BytesForBits(int64(r.BitWidth))))
550
551	r.repCount = 0
552	r.buffer = r.buffer[:0]
553	return
554}
555
// Put buffers input values 8 at a time. after seeing all 8 values,
// it decides whether they should be encoded as a literal or repeated run.
// It returns an error only if flushing a completed run fails.
func (r *RleEncoder) Put(value uint64) error {
	if r.curVal == value {
		r.repCount++
		if r.repCount > 8 {
			// this is just a continuation of the current run, no need to buffer the values
			// NOTE this is the fast path for long repeated runs
			return nil
		}
	} else {
		if r.repCount >= 8 {
			// a run of at least 8 repeats ends here; emit it before
			// starting to track the new value
			if !r.flushRepeated() {
				return xerrors.New("failed to flush repeated value")
			}
		}
		r.repCount = 1
		r.curVal = value
	}

	r.buffer = append(r.buffer, value)
	if len(r.buffer) == 8 {
		// a full group of 8: decide now how to encode it
		return r.flushBuffered(false)
	}
	return nil
}
582
583func (r *RleEncoder) Clear() {
584	r.curVal = 0
585	r.repCount = 0
586	r.buffer = r.buffer[:0]
587	r.litCount = 0
588	r.literalIndicatorOffset = -1
589	r.w.Clear()
590}
591