1// Copyright 2019 Gregory Petrosyan <gregory.petrosyan@gmail.com>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at https://mozilla.org/MPL/2.0/.
6
7package rapid
8
9import (
10	"math"
11	"math/bits"
12)
13
// Labels passed to bitStream.beginGroup by the generators in this file.
// repeatLabel is a suffix: (*repeat).more appends it to a caller-supplied label.
const (
	intBitsLabel  = "intbits"  // raw bit draws for unsigned integers
	biasLabel     = "bias"     // geometric bit-length draw in genUintNBiased
	coinFlipLabel = "coinflip" // flipBiasedCoin
	dieRollLabel  = "dieroll"  // (*loadedDie).roll
	repeatLabel   = "@repeat"  // suffix for repeat groups
)
21
// bitmask64 returns a uint64 with its n lowest bits set. For n >= 64 the
// shift produces zero and the subtraction wraps, so every bit is set.
func bitmask64(n uint) uint64 {
	const one uint64 = 1
	return (one << n) - 1
}
25
26func genFloat01(s bitStream) float64 {
27	return float64(s.drawBits(53)) / (1 << 53)
28}
29
30func genGeom(s bitStream, p float64) uint64 {
31	assert(p > 0 && p <= 1)
32
33	f := genFloat01(s)
34	n := math.Log1p(-f) / math.Log1p(-p)
35
36	return uint64(n)
37}
38
39func genUintNNoReject(s bitStream, max uint64) uint64 {
40	bitlen := bits.Len64(max)
41	i := s.beginGroup(intBitsLabel, false)
42	u := s.drawBits(bitlen)
43	s.endGroup(i, false)
44	if u > max {
45		u = max
46	}
47	return u
48}
49
50func genUintNUnbiased(s bitStream, max uint64) uint64 {
51	bitlen := bits.Len64(max)
52
53	for {
54		i := s.beginGroup(intBitsLabel, false)
55		u := s.drawBits(bitlen)
56		ok := u <= max
57		s.endGroup(i, !ok)
58		if ok {
59			return u
60		}
61	}
62}
63
64func genUintNBiased(s bitStream, max uint64) (uint64, bool, bool) {
65	bitlen := bits.Len64(max)
66	i := s.beginGroup(biasLabel, false)
67	m := math.Max(8, (float64(bitlen)+48)/7)
68	n := genGeom(s, 1/(m+1)) + 1
69	s.endGroup(i, false)
70
71	if int(n) < bitlen {
72		bitlen = int(n)
73	} else if int(n) >= 64-(16-int(m))*4 {
74		bitlen = 65
75	}
76
77	for {
78		i := s.beginGroup(intBitsLabel, false)
79		u := s.drawBits(bitlen)
80		ok := bitlen > 64 || u <= max
81		s.endGroup(i, !ok)
82		if bitlen > 64 {
83			u = max
84		}
85		if u <= max {
86			return u, u == 0 && n == 1, u == max && bitlen >= int(n)
87		}
88	}
89}
90
91func genUintN(s bitStream, max uint64, bias bool) (uint64, bool, bool) {
92	if bias {
93		return genUintNBiased(s, max)
94	} else {
95		return genUintNUnbiased(s, max), false, false
96	}
97}
98
99func genUintRange(s bitStream, min uint64, max uint64, bias bool) (uint64, bool, bool) {
100	assertf(min <= max, "invalid range [%v,  %v]", min, max)
101
102	u, lOverflow, rOverflow := genUintN(s, max-min, bias)
103
104	return min + u, lOverflow, rOverflow
105}
106
107func genIntRange(s bitStream, min int64, max int64, bias bool) (int64, bool, bool) {
108	assertf(min <= max, "invalid range [%v,  %v]", min, max)
109
110	var posMin, negMin uint64
111	var pNeg float64
112	if min >= 0 {
113		posMin = uint64(min)
114		pNeg = 0
115	} else if max <= 0 {
116		negMin = uint64(-max)
117		pNeg = 1
118	} else {
119		posMin = 0
120		negMin = 1
121		pos := uint64(max) + 1
122		neg := uint64(-min)
123		pNeg = float64(neg) / (float64(neg) + float64(pos))
124		if bias {
125			pNeg = 0.5
126		}
127	}
128
129	if flipBiasedCoin(s, pNeg) {
130		u, lOverflow, rOverflow := genUintRange(s, negMin, uint64(-min), bias)
131		return -int64(u), rOverflow, lOverflow && max <= 0
132	} else {
133		u, lOverflow, rOverflow := genUintRange(s, posMin, uint64(max), bias)
134		return int64(u), lOverflow && min >= 0, rOverflow
135	}
136}
137
138func genIndex(s bitStream, n int, bias bool) int {
139	assert(n > 0)
140
141	u, _, _ := genUintN(s, uint64(n-1), bias)
142
143	return int(u)
144}
145
146func flipBiasedCoin(s bitStream, p float64) bool {
147	assert(p >= 0 && p <= 1)
148
149	i := s.beginGroup(coinFlipLabel, false)
150	f := genFloat01(s)
151	s.endGroup(i, false)
152
153	return f >= 1-p
154}
155
// loadedDie is a die with integer-weighted faces. table holds one entry per
// unit of weight, and each entry is a face index (a position in the weights
// slice given to newLoadedDie). An index drawn uniformly from table therefore
// selects face k with probability weights[k] / sum(weights).
type loadedDie struct{ table []int }
159
160func newLoadedDie(weights []int) *loadedDie {
161	assert(len(weights) > 0)
162
163	if len(weights) == 1 {
164		return &loadedDie{
165			table: []int{0},
166		}
167	}
168
169	total := 0
170	for _, w := range weights {
171		assert(w > 0 && w < 100)
172		total += w
173	}
174
175	table := make([]int, total)
176	i := 0
177	for n, w := range weights {
178		for j := i; i < j+w; i++ {
179			table[i] = n
180		}
181	}
182
183	return &loadedDie{
184		table: table,
185	}
186}
187
188func (d *loadedDie) roll(s bitStream) int {
189	i := s.beginGroup(dieRollLabel, false)
190	ix := genIndex(s, len(d.table), false)
191	s.endGroup(i, false)
192
193	return d.table[ix]
194}
195
// repeat tracks a variable-length loop driven by a bitStream. Each call to
// more decides, with a recorded coin flip, whether to produce one more
// element. reject undoes the most recently counted element.
type repeat struct {
	// minCount and maxCount bound count: more always continues while count is
	// below minCount and always stops once count reaches maxCount.
	minCount, maxCount int
	// avgCount is the requested average count; avg rounds it up.
	avgCount float64
	// pContinue is the coin bias more uses once minCount is satisfied.
	pContinue float64
	// count is the number of accepted elements: more increments it, reject
	// decrements it.
	count int
	// group is the beginGroup index of the current element's group, or -1
	// before the first call to more.
	group int
	// rejected marks the current element as rejected. more passes it as the
	// discard flag when it closes that group.
	rejected bool
	// rejections counts every call to reject.
	rejections int
	// forceStop, set by reject after too many rejections, makes more stop
	// once minCount is met.
	forceStop bool
}
207
208func newRepeat(minCount int, maxCount int, avgCount float64) *repeat {
209	if minCount < 0 {
210		minCount = 0
211	}
212	if maxCount < 0 {
213		maxCount = maxInt
214	}
215	if avgCount < 0 {
216		avgCount = float64(minCount) + math.Min(math.Max(float64(minCount), small), (float64(maxCount)-float64(minCount))/2)
217	}
218
219	return &repeat{
220		minCount:  minCount,
221		maxCount:  maxCount,
222		avgCount:  avgCount,
223		pContinue: 1 - 1/(1+avgCount-float64(minCount)), // TODO was no -minCount intentional?
224		group:     -1,
225	}
226}
227
228func (r *repeat) avg() int {
229	return int(math.Ceil(r.avgCount))
230}
231
232func (r *repeat) more(s bitStream, label string) bool {
233	if r.group >= 0 {
234		s.endGroup(r.group, r.rejected)
235	}
236
237	r.group = s.beginGroup(label+repeatLabel, true)
238	r.rejected = false
239
240	pCont := r.pContinue
241	if r.count < r.minCount {
242		pCont = 1
243	} else if r.forceStop || r.count >= r.maxCount {
244		pCont = 0
245	}
246
247	cont := flipBiasedCoin(s, pCont)
248	if cont {
249		r.count++
250	} else {
251		s.endGroup(r.group, false)
252	}
253
254	return cont
255}
256
257func (r *repeat) reject() {
258	assert(r.count > 0)
259	r.count--
260	r.rejected = true
261	r.rejections++
262
263	if r.rejections > r.count*2 {
264		if r.count >= r.minCount {
265			r.forceStop = true
266		} else {
267			panic(invalidData("too many rejections in repeat"))
268		}
269	}
270}
271