1//  Copyright (c) 2017 Couchbase, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// 		http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package regexp
16
import (
	"fmt"
	"regexp/syntax"
	"unicode"
	unicode_utf8 "unicode/utf8"

	"github.com/couchbase/vellum/utf8"
)
25
// compiler translates a parsed regexp/syntax tree into a prog: a flat list of
// *inst instructions addressed by index. A compiler accumulates state in insts
// and instsPool while it runs.
type compiler struct {
	// sizeLimit bounds the estimated program size len(insts)*instSize;
	// checkSize returns ErrCompiledTooBig once it is exceeded.
	sizeLimit uint
	// insts is the program under construction. Instructions are appended in
	// order, and splits/jumps refer to other instructions by index.
	insts     prog
	// instsPool is a slab of inst values handed out one at a time by
	// allocInst, so a new slab is allocated only every 16 instructions.
	instsPool []inst

	// Scratch state passed to (and returned from) utf8.NewSequencesPrealloc,
	// reused across calls when converting rune ranges to UTF-8 byte ranges.
	sequences  utf8.Sequences
	rangeStack utf8.RangeStack
	startBytes []byte
	endBytes   []byte
}
36
37func newCompiler(sizeLimit uint) *compiler {
38	return &compiler{
39		sizeLimit:  sizeLimit,
40		startBytes: make([]byte, unicode_utf8.UTFMax),
41		endBytes:   make([]byte, unicode_utf8.UTFMax),
42	}
43}
44
45func (c *compiler) compile(ast *syntax.Regexp) (prog, error) {
46	err := c.c(ast)
47	if err != nil {
48		return nil, err
49	}
50	inst := c.allocInst()
51	inst.op = OpMatch
52	c.insts = append(c.insts, inst)
53	return c.insts, nil
54}
55
56func (c *compiler) c(ast *syntax.Regexp) (err error) {
57	if ast.Flags&syntax.NonGreedy > 1 {
58		return ErrNoLazy
59	}
60
61	switch ast.Op {
62	case syntax.OpEndLine, syntax.OpBeginLine,
63		syntax.OpBeginText, syntax.OpEndText:
64		return ErrNoEmpty
65	case syntax.OpWordBoundary, syntax.OpNoWordBoundary:
66		return ErrNoWordBoundary
67	case syntax.OpEmptyMatch:
68		return nil
69	case syntax.OpLiteral:
70		for _, r := range ast.Rune {
71			if ast.Flags&syntax.FoldCase > 0 {
72				next := syntax.Regexp{
73					Op:    syntax.OpCharClass,
74					Flags: ast.Flags & syntax.FoldCase,
75					Rune0: [2]rune{r, r},
76				}
77				next.Rune = next.Rune0[0:2]
78				return c.c(&next)
79			}
80			c.sequences, c.rangeStack, err = utf8.NewSequencesPrealloc(
81				r, r, c.sequences, c.rangeStack, c.startBytes, c.endBytes)
82			if err != nil {
83				return err
84			}
85			for _, seq := range c.sequences {
86				c.compileUtf8Ranges(seq)
87			}
88		}
89	case syntax.OpAnyChar:
90		next := syntax.Regexp{
91			Op:    syntax.OpCharClass,
92			Flags: ast.Flags & syntax.FoldCase,
93			Rune0: [2]rune{0, unicode.MaxRune},
94		}
95		next.Rune = next.Rune0[:2]
96		return c.c(&next)
97	case syntax.OpAnyCharNotNL:
98		next := syntax.Regexp{
99			Op:    syntax.OpCharClass,
100			Flags: ast.Flags & syntax.FoldCase,
101			Rune:  []rune{0, 0x09, 0x0B, unicode.MaxRune},
102		}
103		return c.c(&next)
104	case syntax.OpCharClass:
105		return c.compileClass(ast)
106	case syntax.OpCapture:
107		return c.c(ast.Sub[0])
108	case syntax.OpConcat:
109		for _, sub := range ast.Sub {
110			err := c.c(sub)
111			if err != nil {
112				return err
113			}
114		}
115		return nil
116	case syntax.OpAlternate:
117		if len(ast.Sub) == 0 {
118			return nil
119		}
120		jmpsToEnd := make([]uint, 0, len(ast.Sub)-1)
121		// does not handle last entry
122		for i := 0; i < len(ast.Sub)-1; i++ {
123			sub := ast.Sub[i]
124			split := c.emptySplit()
125			j1 := c.top()
126			err := c.c(sub)
127			if err != nil {
128				return err
129			}
130			jmpsToEnd = append(jmpsToEnd, c.emptyJump())
131			j2 := c.top()
132			c.setSplit(split, j1, j2)
133		}
134		// handle last entry
135		err := c.c(ast.Sub[len(ast.Sub)-1])
136		if err != nil {
137			return err
138		}
139		end := uint(len(c.insts))
140		for _, jmpToEnd := range jmpsToEnd {
141			c.setJump(jmpToEnd, end)
142		}
143	case syntax.OpQuest:
144		split := c.emptySplit()
145		j1 := c.top()
146		err := c.c(ast.Sub[0])
147		if err != nil {
148			return err
149		}
150		j2 := c.top()
151		c.setSplit(split, j1, j2)
152
153	case syntax.OpStar:
154		j1 := c.top()
155		split := c.emptySplit()
156		j2 := c.top()
157		err := c.c(ast.Sub[0])
158		if err != nil {
159			return err
160		}
161		jmp := c.emptyJump()
162		j3 := uint(len(c.insts))
163
164		c.setJump(jmp, j1)
165		c.setSplit(split, j2, j3)
166
167	case syntax.OpPlus:
168		j1 := c.top()
169		err := c.c(ast.Sub[0])
170		if err != nil {
171			return err
172		}
173		split := c.emptySplit()
174		j2 := c.top()
175		c.setSplit(split, j1, j2)
176
177	case syntax.OpRepeat:
178		if ast.Max == -1 {
179			for i := 0; i < ast.Min; i++ {
180				err := c.c(ast.Sub[0])
181				if err != nil {
182					return err
183				}
184			}
185			next := syntax.Regexp{
186				Op:    syntax.OpStar,
187				Flags: ast.Flags,
188				Sub:   ast.Sub,
189				Sub0:  ast.Sub0,
190				Rune:  ast.Rune,
191				Rune0: ast.Rune0,
192			}
193			return c.c(&next)
194		}
195		for i := 0; i < ast.Min; i++ {
196			err := c.c(ast.Sub[0])
197			if err != nil {
198				return err
199			}
200		}
201		splits := make([]uint, 0, ast.Max-ast.Min)
202		starts := make([]uint, 0, ast.Max-ast.Min)
203		for i := ast.Min; i < ast.Max; i++ {
204			splits = append(splits, c.emptySplit())
205			starts = append(starts, uint(len(c.insts)))
206			err := c.c(ast.Sub[0])
207			if err != nil {
208				return err
209			}
210		}
211		end := uint(len(c.insts))
212		for i := 0; i < len(splits); i++ {
213			c.setSplit(splits[i], starts[i], end)
214		}
215
216	}
217
218	return c.checkSize()
219}
220
221func (c *compiler) checkSize() error {
222	if uint(len(c.insts)*instSize) > c.sizeLimit {
223		return ErrCompiledTooBig
224	}
225	return nil
226}
227
228func (c *compiler) compileClass(ast *syntax.Regexp) error {
229	if len(ast.Rune) == 0 {
230		return nil
231	}
232	jmps := make([]uint, 0, len(ast.Rune)-2)
233	// does not do last pair
234	for i := 0; i < len(ast.Rune)-2; i += 2 {
235		rstart := ast.Rune[i]
236		rend := ast.Rune[i+1]
237
238		split := c.emptySplit()
239		j1 := c.top()
240		err := c.compileClassRange(rstart, rend)
241		if err != nil {
242			return err
243		}
244		jmps = append(jmps, c.emptyJump())
245		j2 := c.top()
246		c.setSplit(split, j1, j2)
247	}
248	// handle last pair
249	rstart := ast.Rune[len(ast.Rune)-2]
250	rend := ast.Rune[len(ast.Rune)-1]
251	err := c.compileClassRange(rstart, rend)
252	if err != nil {
253		return err
254	}
255	end := c.top()
256	for _, jmp := range jmps {
257		c.setJump(jmp, end)
258	}
259	return nil
260}
261
262func (c *compiler) compileClassRange(startR, endR rune) (err error) {
263	c.sequences, c.rangeStack, err = utf8.NewSequencesPrealloc(
264		startR, endR, c.sequences, c.rangeStack, c.startBytes, c.endBytes)
265	if err != nil {
266		return err
267	}
268	jmps := make([]uint, 0, len(c.sequences)-1)
269	// does not do last entry
270	for i := 0; i < len(c.sequences)-1; i++ {
271		seq := c.sequences[i]
272		split := c.emptySplit()
273		j1 := c.top()
274		c.compileUtf8Ranges(seq)
275		jmps = append(jmps, c.emptyJump())
276		j2 := c.top()
277		c.setSplit(split, j1, j2)
278	}
279	// handle last entry
280	c.compileUtf8Ranges(c.sequences[len(c.sequences)-1])
281	end := c.top()
282	for _, jmp := range jmps {
283		c.setJump(jmp, end)
284	}
285
286	return nil
287}
288
289func (c *compiler) compileUtf8Ranges(seq utf8.Sequence) {
290	for _, r := range seq {
291		inst := c.allocInst()
292		inst.op = OpRange
293		inst.rangeStart = r.Start
294		inst.rangeEnd = r.End
295		c.insts = append(c.insts, inst)
296	}
297}
298
299func (c *compiler) emptySplit() uint {
300	inst := c.allocInst()
301	inst.op = OpSplit
302	c.insts = append(c.insts, inst)
303	return c.top() - 1
304}
305
306func (c *compiler) emptyJump() uint {
307	inst := c.allocInst()
308	inst.op = OpJmp
309	c.insts = append(c.insts, inst)
310	return c.top() - 1
311}
312
313func (c *compiler) setSplit(i, pc1, pc2 uint) {
314	split := c.insts[i]
315	split.splitA = pc1
316	split.splitB = pc2
317}
318
319func (c *compiler) setJump(i, pc uint) {
320	jmp := c.insts[i]
321	jmp.to = pc
322}
323
324func (c *compiler) top() uint {
325	return uint(len(c.insts))
326}
327
328func (c *compiler) allocInst() *inst {
329	if len(c.instsPool) <= 0 {
330		c.instsPool = make([]inst, 16)
331	}
332	inst := &c.instsPool[0]
333	c.instsPool = c.instsPool[1:]
334	return inst
335}
336