1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package prefix
6
// RangeCode describes one range of values. The range begins at Base and a
// value inside it is expressed as Base plus a Len-bit integer.
type RangeCode struct {
	// Base is the starting base offset of the range.
	Base uint32

	// Len is the bit-length of a subsequent integer to add to the base offset.
	Len uint32
}
// RangeCodes is a list of RangeCode entries, one per range.
type RangeCodes []RangeCode
12
13type RangeEncoder struct {
14	rcs     RangeCodes
15	lut     [1024]uint32
16	minBase uint
17}
18
19// End reports the non-inclusive ending range.
20func (rc RangeCode) End() uint32 { return rc.Base + (1 << rc.Len) }
21
22// MakeRangeCodes creates a RangeCodes, where each region is assumed to be
23// contiguously stacked, without any gaps, with bit-lengths taken from bits.
24func MakeRangeCodes(minBase uint, bits []uint) (rc RangeCodes) {
25	for _, nb := range bits {
26		rc = append(rc, RangeCode{Base: uint32(minBase), Len: uint32(nb)})
27		minBase += 1 << nb
28	}
29	return rc
30}
31
32// Base reports the inclusive starting range for all ranges.
33func (rcs RangeCodes) Base() uint32 { return rcs[0].Base }
34
35// End reports the non-inclusive ending range for all ranges.
36func (rcs RangeCodes) End() uint32 { return rcs[len(rcs)-1].End() }
37
38// checkValid reports whether the RangeCodes is valid. In order to be valid,
39// the following must hold true:
40//	rcs[i-1].Base <= rcs[i].Base
41//	rcs[i-1].End  <= rcs[i].End
42//	rcs[i-1].End  >= rcs[i].Base
43//
44// Practically speaking, each range must be increasing and must not have any
45// gaps in between. It is okay for ranges to overlap.
46func (rcs RangeCodes) checkValid() bool {
47	if len(rcs) == 0 {
48		return false
49	}
50	pre := rcs[0]
51	for _, cur := range rcs[1:] {
52		preBase, preEnd := pre.Base, pre.End()
53		curBase, curEnd := cur.Base, cur.End()
54		if preBase > curBase || preEnd > curEnd || preEnd < curBase {
55			return false
56		}
57		pre = cur
58	}
59	return true
60}
61
62func (re *RangeEncoder) Init(rcs RangeCodes) {
63	if !rcs.checkValid() {
64		panic("invalid range codes")
65	}
66	*re = RangeEncoder{rcs: rcs, minBase: uint(rcs.Base())}
67	for sym, rc := range rcs {
68		base := int(rc.Base) - int(re.minBase)
69		end := int(rc.End()) - int(re.minBase)
70		if base >= len(re.lut) {
71			break
72		}
73		if end > len(re.lut) {
74			end = len(re.lut)
75		}
76		for i := base; i < end; i++ {
77			re.lut[i] = uint32(sym)
78		}
79	}
80}
81
82func (re *RangeEncoder) Encode(offset uint) (sym uint) {
83	if idx := int(offset - re.minBase); idx < len(re.lut) {
84		return uint(re.lut[idx])
85	}
86	sym = uint(re.lut[len(re.lut)-1])
87retry:
88	if int(sym) >= len(re.rcs) || re.rcs[sym].Base > uint32(offset) {
89		return sym - 1
90	}
91	sym++
92	goto retry // Avoid for-loop so that this function can be inlined
93}
94