1// Copyright 2019 Gregory Petrosyan <gregory.petrosyan@gmail.com> 2// 3// This Source Code Form is subject to the terms of the Mozilla Public 4// License, v. 2.0. If a copy of the MPL was not distributed with this 5// file, You can obtain one at https://mozilla.org/MPL/2.0/. 6 7package rapid 8 9import ( 10 "math" 11 "math/bits" 12) 13 14const ( 15 biasLabel = "bias" 16 intBitsLabel = "intbits" 17 coinFlipLabel = "coinflip" 18 dieRollLabel = "dieroll" 19 repeatLabel = "@repeat" 20) 21 22func bitmask64(n uint) uint64 { 23 return uint64(1)<<n - 1 24} 25 26func genFloat01(s bitStream) float64 { 27 return float64(s.drawBits(53)) / (1 << 53) 28} 29 30func genGeom(s bitStream, p float64) uint64 { 31 assert(p > 0 && p <= 1) 32 33 f := genFloat01(s) 34 n := math.Log1p(-f) / math.Log1p(-p) 35 36 return uint64(n) 37} 38 39func genUintNNoReject(s bitStream, max uint64) uint64 { 40 bitlen := bits.Len64(max) 41 i := s.beginGroup(intBitsLabel, false) 42 u := s.drawBits(bitlen) 43 s.endGroup(i, false) 44 if u > max { 45 u = max 46 } 47 return u 48} 49 50func genUintNUnbiased(s bitStream, max uint64) uint64 { 51 bitlen := bits.Len64(max) 52 53 for { 54 i := s.beginGroup(intBitsLabel, false) 55 u := s.drawBits(bitlen) 56 ok := u <= max 57 s.endGroup(i, !ok) 58 if ok { 59 return u 60 } 61 } 62} 63 64func genUintNBiased(s bitStream, max uint64) (uint64, bool, bool) { 65 bitlen := bits.Len64(max) 66 i := s.beginGroup(biasLabel, false) 67 m := math.Max(8, (float64(bitlen)+48)/7) 68 n := genGeom(s, 1/(m+1)) + 1 69 s.endGroup(i, false) 70 71 if int(n) < bitlen { 72 bitlen = int(n) 73 } else if int(n) >= 64-(16-int(m))*4 { 74 bitlen = 65 75 } 76 77 for { 78 i := s.beginGroup(intBitsLabel, false) 79 u := s.drawBits(bitlen) 80 ok := bitlen > 64 || u <= max 81 s.endGroup(i, !ok) 82 if bitlen > 64 { 83 u = max 84 } 85 if u <= max { 86 return u, u == 0 && n == 1, u == max && bitlen >= int(n) 87 } 88 } 89} 90 91func genUintN(s bitStream, max uint64, bias bool) (uint64, bool, bool) { 92 if bias { 93 return genUintNBiased(s, max) 94 } else { 95 return genUintNUnbiased(s, max), false, false 96 } 97} 98 99func genUintRange(s bitStream, min uint64, max uint64, bias bool) (uint64, bool, bool) { 100 assertf(min <= max, "invalid range [%v, %v]", min, max) 101 102 u, lOverflow, rOverflow := genUintN(s, max-min, bias) 103 104 return min + u, lOverflow, rOverflow 105} 106 107func genIntRange(s bitStream, min int64, max int64, bias bool) (int64, bool, bool) { 108 assertf(min <= max, "invalid range [%v, %v]", min, max) 109 110 var posMin, negMin uint64 111 var pNeg float64 112 if min >= 0 { 113 posMin = uint64(min) 114 pNeg = 0 115 } else if max <= 0 { 116 negMin = uint64(-max) 117 pNeg = 1 118 } else { 119 posMin = 0 120 negMin = 1 121 pos := uint64(max) + 1 122 neg := uint64(-min) 123 pNeg = float64(neg) / (float64(neg) + float64(pos)) 124 if bias { 125 pNeg = 0.5 126 } 127 } 128 129 if flipBiasedCoin(s, pNeg) { 130 u, lOverflow, rOverflow := genUintRange(s, negMin, uint64(-min), bias) 131 return -int64(u), rOverflow, lOverflow && max <= 0 132 } else { 133 u, lOverflow, rOverflow := genUintRange(s, posMin, uint64(max), bias) 134 return int64(u), lOverflow && min >= 0, rOverflow 135 } 136} 137 138func genIndex(s bitStream, n int, bias bool) int { 139 assert(n > 0) 140 141 u, _, _ := genUintN(s, uint64(n-1), bias) 142 143 return int(u) 144} 145 146func flipBiasedCoin(s bitStream, p float64) bool { 147 assert(p >= 0 && p <= 1) 148 149 i := s.beginGroup(coinFlipLabel, false) 150 f := genFloat01(s) 151 s.endGroup(i, false) 152 153 return f >= 1-p 154} 155 156type loadedDie struct { 157 table []int 158} 159 160func newLoadedDie(weights []int) *loadedDie { 161 assert(len(weights) > 0) 162 163 if len(weights) == 1 { 164 return &loadedDie{ 165 table: []int{0}, 166 } 167 } 168 169 total := 0 170 for _, w := range weights { 171 assert(w > 0 && w < 100) 172 total += w 173 } 174 175 table := make([]int, total) 176 i := 0 177 for n, w := range weights { 178 for j := i; i < j+w; i++ { 179 table[i] = n 180 } 181 } 182 183 return &loadedDie{ 184 table: table, 185 } 186} 187 188func (d *loadedDie) roll(s bitStream) int { 189 i := s.beginGroup(dieRollLabel, false) 190 ix := genIndex(s, len(d.table), false) 191 s.endGroup(i, false) 192 193 return d.table[ix] 194} 195 196type repeat struct { 197 minCount int 198 maxCount int 199 avgCount float64 200 pContinue float64 201 count int 202 group int 203 rejected bool 204 rejections int 205 forceStop bool 206} 207 208func newRepeat(minCount int, maxCount int, avgCount float64) *repeat { 209 if minCount < 0 { 210 minCount = 0 211 } 212 if maxCount < 0 { 213 maxCount = maxInt 214 } 215 if avgCount < 0 { 216 avgCount = float64(minCount) + math.Min(math.Max(float64(minCount), small), (float64(maxCount)-float64(minCount))/2) 217 } 218 219 return &repeat{ 220 minCount: minCount, 221 maxCount: maxCount, 222 avgCount: avgCount, 223 pContinue: 1 - 1/(1+avgCount-float64(minCount)), // TODO was no -minCount intentional? 224 group: -1, 225 } 226} 227 228func (r *repeat) avg() int { 229 return int(math.Ceil(r.avgCount)) 230} 231 232func (r *repeat) more(s bitStream, label string) bool { 233 if r.group >= 0 { 234 s.endGroup(r.group, r.rejected) 235 } 236 237 r.group = s.beginGroup(label+repeatLabel, true) 238 r.rejected = false 239 240 pCont := r.pContinue 241 if r.count < r.minCount { 242 pCont = 1 243 } else if r.forceStop || r.count >= r.maxCount { 244 pCont = 0 245 } 246 247 cont := flipBiasedCoin(s, pCont) 248 if cont { 249 r.count++ 250 } else { 251 s.endGroup(r.group, false) 252 } 253 254 return cont 255} 256 257func (r *repeat) reject() { 258 assert(r.count > 0) 259 r.count-- 260 r.rejected = true 261 r.rejections++ 262 263 if r.rejections > r.count*2 { 264 if r.count >= r.minCount { 265 r.forceStop = true 266 } else { 267 panic(invalidData("too many rejections in repeat")) 268 } 269 } 270} 271