1// Copyright 2019 Gregory Petrosyan <gregory.petrosyan@gmail.com>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at https://mozilla.org/MPL/2.0/.
6
7package rapid
8
9import (
10	"bytes"
11	"fmt"
12	"reflect"
13	"regexp"
14	"regexp/syntax"
15	"strings"
16	"sync"
17	"unicode"
18	"unicode/utf8"
19)
20
// Package-level data shared by the rune, string and regexp generators below.
var (
	// Reflected result types reported by the type_() methods of stringGen
	// and regexpGen.
	stringType    = reflect.TypeOf("")
	byteSliceType = reflect.TypeOf([]byte(nil))

	// defaultRunes is the explicit rune list that Rune() hands to runesFrom,
	// next to defaultTables.
	defaultRunes = []rune{
		'?',
		'~', '!', '@', '#', '$', '%', '^', '&', '*', '_', '-', '+', '=',
		'.', ',', ':', ';',
		' ', '\t', '\r', '\n',
		'/', '\\', '|',
		'(', '[', '{', '<',
		'\'', '"', '`',
		'\x00', '\x0B', '\x1B', '\x7F', // NUL, VT, ESC, DEL
		'\uFEFF', '\uFFFD', '\u202E', // BOM, replacement character, RTL override
		'Ⱥ', // In UTF-8, Ⱥ increases in length from 2 to 3 bytes when lowercased
	}

	// unicode.Categories without surrogates (which are not allowed in UTF-8), ordered by taste
	defaultTables = []*unicode.RangeTable{
		unicode.Lu, // Letter, uppercase        (1781)
		unicode.Ll, // Letter, lowercase        (2145)
		unicode.Lt, // Letter, titlecase          (31)
		unicode.Lm, // Letter, modifier          (250)
		unicode.Lo, // Letter, other          (121212)
		unicode.Nd, // Number, decimal digit     (610)
		unicode.Nl, // Number, letter            (236)
		unicode.No, // Number, other             (807)
		unicode.P,  // Punctuation               (788)
		unicode.Sm, // Symbol, math              (948)
		unicode.Sc, // Symbol, currency           (57)
		unicode.Sk, // Symbol, modifier          (121)
		unicode.So, // Symbol, other            (5984)
		unicode.Mn, // Mark, nonspacing         (1805)
		unicode.Me, // Mark, enclosing            (13)
		unicode.Mc, // Mark, spacing combining   (415)
		unicode.Z,  // Separator                  (19)
		unicode.Cc, // Other, control             (65)
		unicode.Cf, // Other, format             (152)
		unicode.Co, // Other, private use     (137468)
	}

	// Memoization caches filled lazily by expandRangeTable, compileRegexp,
	// regexpName and charClassGen respectively (key -> value noted per line).
	expandedTables  = sync.Map{} // *unicode.RangeTable / regexp name -> []rune
	compiledRegexps = sync.Map{} // regexp expression string -> compiledRegexp
	regexpNames     = sync.Map{} // *syntax.Regexp -> string
	charClassGens   = sync.Map{} // regexp name -> *Generator

	// Shared rune generators: anyRuneGen backs String()/StringN() and the
	// regexp "any char" op; anyRuneGenNoNL is the same generator with '\n'
	// filtered out, used for "any char except newline".
	anyRuneGen     = Rune()
	anyRuneGenNoNL = Rune().Filter(func(r rune) bool { return r != '\n' })
)
70
// compiledRegexp pairs the two representations of one regular expression
// that compileRegexp produces and caches together.
type compiledRegexp struct {
	syn *syntax.Regexp // parsed syntax tree, walked by regexpGen.build
	re  *regexp.Regexp // compiled matcher, used to validate generated output
}
75
// Rune returns a generator of runes drawn from defaultRunes and from the
// Unicode categories listed in defaultTables.
func Rune() *Generator {
	return runesFrom(true, defaultRunes, defaultTables...)
}
79
// RuneFrom returns a generator of runes drawn from the explicit runes slice
// and from the given Unicode range tables. At least one rune or one table
// has to be supplied, and no supplied table may be empty; runesFrom asserts
// both conditions.
func RuneFrom(runes []rune, tables ...*unicode.RangeTable) *Generator {
	return runesFrom(false, runes, tables...)
}
83
84func runesFrom(default_ bool, runes []rune, tables ...*unicode.RangeTable) *Generator {
85	if len(tables) == 0 {
86		assertf(len(runes) > 0, "at least one rune should be specified")
87	}
88	if len(runes) == 0 {
89		assertf(len(tables) > 0, "at least one *unicode.RangeTable should be specified")
90	}
91
92	var weights []int
93	if len(runes) > 0 {
94		weights = append(weights, len(tables))
95	}
96	for range tables {
97		weights = append(weights, 1)
98	}
99
100	tables_ := make([][]rune, len(tables))
101	for i := range tables {
102		tables_[i] = expandRangeTable(tables[i], tables[i])
103		assertf(len(tables_[i]) > 0, "empty *unicode.RangeTable %v", i)
104	}
105
106	return newGenerator(&runeGen{
107		die:      newLoadedDie(weights),
108		runes:    runes,
109		tables:   tables_,
110		default_: default_,
111	})
112}
113
// runeGen is the generator implementation behind Rune, RuneFrom and regexp
// character classes (see charClassGen).
type runeGen struct {
	die      *loadedDie // rolled once per value to pick runes or one of tables
	runes    []rune     // explicit runes; when non-empty they own die outcome 0
	tables   [][]rune   // expanded range tables, one rune slice per table
	default_ bool       // true when created by Rune(); only affects String()
}
120
121func (g *runeGen) String() string {
122	if g.default_ {
123		return "Rune()"
124	} else {
125		return fmt.Sprintf("Rune(%v runes, %v tables)", len(g.runes), len(g.tables))
126	}
127}
128
// type_ reports the reflected type of the generated values (int32Type,
// rune being an alias of int32).
func (g *runeGen) type_() reflect.Type {
	return int32Type
}
132
133func (g *runeGen) value(t *T) value {
134	n := g.die.roll(t.s)
135
136	runes := g.runes
137	if len(g.runes) == 0 {
138		runes = g.tables[n]
139	} else if n > 0 {
140		runes = g.tables[n-1]
141	}
142
143	return runes[genIndex(t.s, len(runes), true)]
144}
145
// String returns a generator of strings whose runes come from the default
// rune generator (anyRuneGen), with no explicit element-count or length
// limits (StringOf passes -1 for all of them).
func String() *Generator {
	return StringOf(anyRuneGen)
}
149
// StringN is like String with explicit limits: minRunes and maxRunes bound
// the number of runes and maxLen bounds the length in bytes. The arguments
// are forwarded unchanged to StringOfN, which validates them.
func StringN(minRunes int, maxRunes int, maxLen int) *Generator {
	return StringOfN(anyRuneGen, minRunes, maxRunes, maxLen)
}
153
// StringOf returns a generator of strings assembled from elements produced
// by elem. It is StringOfN with every limit set to -1, i.e. without explicit
// bounds; StringOfN asserts that elem generates runes or bytes.
func StringOf(elem *Generator) *Generator {
	return StringOfN(elem, -1, -1, -1)
}
157
158func StringOfN(elem *Generator, minElems int, maxElems int, maxLen int) *Generator {
159	assertValidRange(minElems, maxElems)
160	assertf(elem.type_() == int32Type || elem.type_() == uint8Type, "element generator should generate runes or bytes, not %v", elem.type_())
161	assertf(maxLen < 0 || maxLen >= maxElems, "maximum length (%v) should not be less than maximum number of elements (%v)", maxLen, maxElems)
162
163	return newGenerator(&stringGen{
164		elem:     elem,
165		minElems: minElems,
166		maxElems: maxElems,
167		maxLen:   maxLen,
168	})
169}
170
// stringGen is the generator implementation behind String, StringN,
// StringOf and StringOfN.
type stringGen struct {
	elem     *Generator // element generator; StringOfN asserts it yields runes or bytes
	minElems int        // lower bound on the element count, passed to newRepeat
	maxElems int        // upper bound on the element count, passed to newRepeat
	maxLen   int        // cap on the result length in bytes; negative means no cap
}
177
178func (g *stringGen) String() string {
179	if g.elem == anyRuneGen {
180		if g.minElems < 0 && g.maxElems < 0 && g.maxLen < 0 {
181			return "String()"
182		} else {
183			return fmt.Sprintf("StringN(minRunes=%v, maxRunes=%v, maxLen=%v)", g.minElems, g.maxElems, g.maxLen)
184		}
185	} else {
186		if g.minElems < 0 && g.maxElems < 0 && g.maxLen < 0 {
187			return fmt.Sprintf("StringOf(%v)", g.elem)
188		} else {
189			return fmt.Sprintf("StringOfN(%v, minElems=%v, maxElems=%v, maxLen=%v)", g.elem, g.minElems, g.maxElems, g.maxLen)
190		}
191	}
192}
193
// type_ reports the reflected type of the generated values (stringType).
func (g *stringGen) type_() reflect.Type {
	return stringType
}
197
198func (g *stringGen) value(t *T) value {
199	repeat := newRepeat(g.minElems, g.maxElems, -1)
200
201	var b strings.Builder
202	b.Grow(repeat.avg())
203
204	if g.elem.type_() == int32Type {
205		maxLen := g.maxLen
206		if maxLen < 0 {
207			maxLen = maxInt
208		}
209
210		for repeat.more(t.s, g.elem.String()) {
211			r := g.elem.value(t).(rune)
212			n := utf8.RuneLen(r)
213
214			if n < 0 || b.Len()+n > maxLen {
215				repeat.reject()
216			} else {
217				b.WriteRune(r)
218			}
219		}
220	} else {
221		for repeat.more(t.s, g.elem.String()) {
222			b.WriteByte(g.elem.value(t).(byte))
223		}
224	}
225
226	return b.String()
227}
228
// StringMatching returns a generator of strings that match the regular
// expression expr. The expression must parse and compile; matching asserts
// that it does.
func StringMatching(expr string) *Generator {
	return matching(expr, true)
}
232
// SliceOfBytesMatching returns a generator of byte slices that match the
// regular expression expr. The expression must parse and compile; matching
// asserts that it does.
func SliceOfBytesMatching(expr string) *Generator {
	return matching(expr, false)
}
236
237func matching(expr string, str bool) *Generator {
238	compiled, err := compileRegexp(expr)
239	assertf(err == nil, "%v", err)
240
241	return newGenerator(&regexpGen{
242		str:  str,
243		expr: expr,
244		syn:  compiled.syn,
245		re:   compiled.re,
246	})
247}
248
// runeWriter is the output sink used by regexpGen.build. In this file it is
// satisfied by *strings.Builder (maybeString) and *bytes.Buffer (maybeSlice).
type runeWriter interface {
	WriteRune(r rune) (int, error)
}
252
// regexpGen is the generator implementation behind StringMatching and
// SliceOfBytesMatching.
type regexpGen struct {
	str  bool           // true: generate strings; false: generate []byte
	expr string         // original expression text, shown by String()
	syn  *syntax.Regexp // parsed tree walked by build to produce candidates
	re   *regexp.Regexp // compiled matcher used to validate candidates
}
259
260func (g *regexpGen) String() string {
261	if g.str {
262		return fmt.Sprintf("StringMatching(%q)", g.expr)
263	} else {
264		return fmt.Sprintf("SliceOfBytesMatching(%q)", g.expr)
265	}
266}
267
268func (g *regexpGen) type_() reflect.Type {
269	if g.str {
270		return stringType
271	} else {
272		return byteSliceType
273	}
274}
275
276func (g *regexpGen) maybeString(t *T) value {
277	b := &strings.Builder{}
278	g.build(b, g.syn, t)
279	v := b.String()
280
281	if g.re.MatchString(v) {
282		return v
283	} else {
284		return nil
285	}
286}
287
288func (g *regexpGen) maybeSlice(t *T) value {
289	b := &bytes.Buffer{}
290	g.build(b, g.syn, t)
291	v := b.Bytes()
292
293	if g.re.Match(v) {
294		return v
295	} else {
296		return nil
297	}
298}
299
300func (g *regexpGen) value(t *T) value {
301	if g.str {
302		return find(g.maybeString, t, small)
303	} else {
304		return find(g.maybeSlice, t, small)
305	}
306}
307
// build appends to w one candidate match for the regexp node re, recursing
// into its sub-expressions. All draws made for the node are wrapped in a
// bit-stream group labelled with the node's op name.
//
// The output is only a candidate: anchors and word boundaries are not
// modelled here, which is why maybeString / maybeSlice re-check the result
// against the compiled regexp. An OpNoMatch node panics with invalidData,
// and an op not listed below fails an assertion.
func (g *regexpGen) build(w runeWriter, re *syntax.Regexp, t *T) {
	i := t.s.beginGroup(re.Op.String(), false)

	switch re.Op {
	case syntax.OpNoMatch:
		// Nothing can satisfy this node.
		panic(invalidData("no possible regexp match"))
	case syntax.OpEmptyMatch:
		// Matches the empty string: nothing is written.
		// NOTE(review): the zero-bit draw presumably just registers a draw
		// inside the group; confirm against the bitStream implementation.
		t.s.drawBits(0)
	case syntax.OpLiteral:
		t.s.drawBits(0)
		// Emit the literal rune by rune; with the FoldCase flag set each
		// rune may be replaced by a case-folded variant. The WriteRune
		// results are deliberately discarded.
		for _, r := range re.Rune {
			_, _ = w.WriteRune(maybeFoldCase(t.s, r, re.Flags))
		}
	case syntax.OpCharClass, syntax.OpAnyCharNotNL, syntax.OpAnyChar:
		// Choose the rune generator for the node: any rune by default, the
		// class-specific generator for a character class, or any rune but
		// '\n' for "any char except newline".
		sub := anyRuneGen
		switch re.Op {
		case syntax.OpCharClass:
			sub = charClassGen(re)
		case syntax.OpAnyCharNotNL:
			sub = anyRuneGenNoNL
		}
		r := sub.value(t).(rune)
		_, _ = w.WriteRune(maybeFoldCase(t.s, r, re.Flags))
	case syntax.OpBeginLine, syntax.OpEndLine,
		syntax.OpBeginText, syntax.OpEndText,
		syntax.OpWordBoundary, syntax.OpNoWordBoundary:
		t.s.drawBits(0) // do nothing and hope that Assume() is enough
	case syntax.OpCapture:
		// A capture group generates exactly what its sub-expression does.
		g.build(w, re.Sub[0], t)
	case syntax.OpStar, syntax.OpPlus, syntax.OpQuest, syntax.OpRepeat:
		// Translate the quantifier into repeat bounds; OpRepeat carries its
		// own Min/Max, the other ops have fixed bounds (-1 is passed as the
		// upper bound for * and +).
		min, max := re.Min, re.Max
		switch re.Op {
		case syntax.OpStar:
			min, max = 0, -1
		case syntax.OpPlus:
			min, max = 1, -1
		case syntax.OpQuest:
			min, max = 0, 1
		}
		repeat := newRepeat(min, max, -1)
		for repeat.more(t.s, regexpName(re.Sub[0])) {
			g.build(w, re.Sub[0], t)
		}
	case syntax.OpConcat:
		// Generate every part in order.
		for _, sub := range re.Sub {
			g.build(w, sub, t)
		}
	case syntax.OpAlternate:
		// Pick one alternative by index and generate only that one.
		ix := genIndex(t.s, len(re.Sub), true)
		g.build(w, re.Sub[ix], t)
	default:
		assertf(false, "invalid regexp op %v", re.Op)
	}

	t.s.endGroup(i, false)
}
364
365func maybeFoldCase(s bitStream, r rune, flags syntax.Flags) rune {
366	n := uint64(0)
367	if flags&syntax.FoldCase != 0 {
368		n, _, _ = genUintN(s, 4, false)
369	}
370
371	for i := 0; i < int(n); i++ {
372		r = unicode.SimpleFold(r)
373	}
374
375	return r
376}
377
378func expandRangeTable(t *unicode.RangeTable, key interface{}) []rune {
379	cached, ok := expandedTables.Load(key)
380	if ok {
381		return cached.([]rune)
382	}
383
384	var ret []rune
385	for _, r := range t.R16 {
386		for i := r.Lo; i <= r.Hi; i += r.Stride {
387			ret = append(ret, rune(i))
388		}
389	}
390	for _, r := range t.R32 {
391		for i := r.Lo; i <= r.Hi; i += r.Stride {
392			ret = append(ret, rune(i))
393		}
394	}
395	expandedTables.Store(key, ret)
396
397	return ret
398}
399
400func compileRegexp(expr string) (compiledRegexp, error) {
401	cached, ok := compiledRegexps.Load(expr)
402	if ok {
403		return cached.(compiledRegexp), nil
404	}
405
406	syn, err := syntax.Parse(expr, syntax.Perl)
407	if err != nil {
408		return compiledRegexp{}, fmt.Errorf("failed to parse regexp %q: %v", expr, err)
409	}
410
411	re, err := regexp.Compile(expr)
412	if err != nil {
413		return compiledRegexp{}, fmt.Errorf("failed to compile regexp %q: %v", expr, err)
414	}
415
416	ret := compiledRegexp{syn, re}
417	compiledRegexps.Store(expr, ret)
418
419	return ret, nil
420}
421
422func regexpName(re *syntax.Regexp) string {
423	cached, ok := regexpNames.Load(re)
424	if ok {
425		return cached.(string)
426	}
427
428	s := re.String()
429	regexpNames.Store(re, s)
430
431	return s
432}
433
434func charClassGen(re *syntax.Regexp) *Generator {
435	cached, ok := charClassGens.Load(regexpName(re))
436	if ok {
437		return cached.(*Generator)
438	}
439
440	t := &unicode.RangeTable{}
441	for i := 0; i < len(re.Rune); i += 2 {
442		t.R32 = append(t.R32, unicode.Range32{
443			Lo:     uint32(re.Rune[i]),
444			Hi:     uint32(re.Rune[i+1]),
445			Stride: 1,
446		})
447	}
448
449	g := newGenerator(&runeGen{
450		die:    newLoadedDie([]int{1}),
451		tables: [][]rune{expandRangeTable(t, regexpName(re))},
452	})
453	charClassGens.Store(regexpName(re), g)
454
455	return g
456}
457