1//  Copyright (c) 2017 Couchbase, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// 		http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package regexp
16
import (
	"fmt"
	"regexp/syntax"
	"unicode"

	unicode_utf8 "unicode/utf8"

	"github.com/couchbase/vellum/utf8"
)
25
// compiler translates a parsed regexp/syntax tree into a prog of
// instructions (see compile). compile appends to insts without resetting
// it, so a compiler appears intended for a single compilation.
// NOTE(review): confirm callers never reuse one.
type compiler struct {
	// sizeLimit bounds len(insts)*instSize (presumably bytes); checkSize
	// returns ErrCompiledTooBig once it is exceeded.
	sizeLimit uint
	// insts is the program under construction, in emission order; indexes
	// into it are the program counters used by Split/Jmp targets.
	insts     prog
	// instsPool is a slab that allocInst carves instructions from,
	// refilled 16 at a time, to avoid one allocation per instruction.
	instsPool []inst

	// Scratch state reused across utf8.NewSequencesPrealloc calls so that
	// compiling many rune ranges does not reallocate these buffers.
	sequences  utf8.Sequences
	rangeStack utf8.RangeStack
	startBytes []byte
	endBytes   []byte
}
36
37func newCompiler(sizeLimit uint) *compiler {
38	return &compiler{
39		sizeLimit:  sizeLimit,
40		startBytes: make([]byte, unicode_utf8.UTFMax),
41		endBytes:   make([]byte, unicode_utf8.UTFMax),
42	}
43}
44
45func (c *compiler) compile(ast *syntax.Regexp) (prog, error) {
46	err := c.c(ast)
47	if err != nil {
48		return nil, err
49	}
50	inst := c.allocInst()
51	inst.op = OpMatch
52	c.insts = append(c.insts, inst)
53	return c.insts, nil
54}
55
56func (c *compiler) c(ast *syntax.Regexp) (err error) {
57	if ast.Flags&syntax.NonGreedy > 1 {
58		return ErrNoLazy
59	}
60
61	switch ast.Op {
62	case syntax.OpEndLine, syntax.OpBeginLine,
63		syntax.OpBeginText, syntax.OpEndText:
64		return ErrNoEmpty
65	case syntax.OpWordBoundary, syntax.OpNoWordBoundary:
66		return ErrNoWordBoundary
67	case syntax.OpEmptyMatch:
68		return nil
69	case syntax.OpLiteral:
70		for _, r := range ast.Rune {
71			if ast.Flags&syntax.FoldCase > 0 {
72				next := syntax.Regexp{
73					Op:    syntax.OpCharClass,
74					Flags: ast.Flags & syntax.FoldCase,
75					Rune0: [2]rune{r, r},
76				}
77				next.Rune = next.Rune0[0:2]
78				// try to find more folded runes
79				for r1 := unicode.SimpleFold(r); r1 != r; r1 = unicode.SimpleFold(r1) {
80					next.Rune = append(next.Rune, r1, r1)
81				}
82				err = c.c(&next)
83				if err != nil {
84					return err
85				}
86			} else {
87				c.sequences, c.rangeStack, err = utf8.NewSequencesPrealloc(
88					r, r, c.sequences, c.rangeStack, c.startBytes, c.endBytes)
89				if err != nil {
90					return err
91				}
92				for _, seq := range c.sequences {
93					c.compileUtf8Ranges(seq)
94				}
95			}
96		}
97	case syntax.OpAnyChar:
98		next := syntax.Regexp{
99			Op:    syntax.OpCharClass,
100			Flags: ast.Flags & syntax.FoldCase,
101			Rune0: [2]rune{0, unicode.MaxRune},
102		}
103		next.Rune = next.Rune0[:2]
104		return c.c(&next)
105	case syntax.OpAnyCharNotNL:
106		next := syntax.Regexp{
107			Op:    syntax.OpCharClass,
108			Flags: ast.Flags & syntax.FoldCase,
109			Rune:  []rune{0, 0x09, 0x0B, unicode.MaxRune},
110		}
111		return c.c(&next)
112	case syntax.OpCharClass:
113		return c.compileClass(ast)
114	case syntax.OpCapture:
115		return c.c(ast.Sub[0])
116	case syntax.OpConcat:
117		for _, sub := range ast.Sub {
118			err := c.c(sub)
119			if err != nil {
120				return err
121			}
122		}
123		return nil
124	case syntax.OpAlternate:
125		if len(ast.Sub) == 0 {
126			return nil
127		}
128		jmpsToEnd := make([]uint, 0, len(ast.Sub)-1)
129		// does not handle last entry
130		for i := 0; i < len(ast.Sub)-1; i++ {
131			sub := ast.Sub[i]
132			split := c.emptySplit()
133			j1 := c.top()
134			err := c.c(sub)
135			if err != nil {
136				return err
137			}
138			jmpsToEnd = append(jmpsToEnd, c.emptyJump())
139			j2 := c.top()
140			c.setSplit(split, j1, j2)
141		}
142		// handle last entry
143		err := c.c(ast.Sub[len(ast.Sub)-1])
144		if err != nil {
145			return err
146		}
147		end := uint(len(c.insts))
148		for _, jmpToEnd := range jmpsToEnd {
149			c.setJump(jmpToEnd, end)
150		}
151	case syntax.OpQuest:
152		split := c.emptySplit()
153		j1 := c.top()
154		err := c.c(ast.Sub[0])
155		if err != nil {
156			return err
157		}
158		j2 := c.top()
159		c.setSplit(split, j1, j2)
160
161	case syntax.OpStar:
162		j1 := c.top()
163		split := c.emptySplit()
164		j2 := c.top()
165		err := c.c(ast.Sub[0])
166		if err != nil {
167			return err
168		}
169		jmp := c.emptyJump()
170		j3 := uint(len(c.insts))
171
172		c.setJump(jmp, j1)
173		c.setSplit(split, j2, j3)
174
175	case syntax.OpPlus:
176		j1 := c.top()
177		err := c.c(ast.Sub[0])
178		if err != nil {
179			return err
180		}
181		split := c.emptySplit()
182		j2 := c.top()
183		c.setSplit(split, j1, j2)
184
185	case syntax.OpRepeat:
186		if ast.Max == -1 {
187			for i := 0; i < ast.Min; i++ {
188				err := c.c(ast.Sub[0])
189				if err != nil {
190					return err
191				}
192			}
193			next := syntax.Regexp{
194				Op:    syntax.OpStar,
195				Flags: ast.Flags,
196				Sub:   ast.Sub,
197				Sub0:  ast.Sub0,
198				Rune:  ast.Rune,
199				Rune0: ast.Rune0,
200			}
201			return c.c(&next)
202		}
203		for i := 0; i < ast.Min; i++ {
204			err := c.c(ast.Sub[0])
205			if err != nil {
206				return err
207			}
208		}
209		splits := make([]uint, 0, ast.Max-ast.Min)
210		starts := make([]uint, 0, ast.Max-ast.Min)
211		for i := ast.Min; i < ast.Max; i++ {
212			splits = append(splits, c.emptySplit())
213			starts = append(starts, uint(len(c.insts)))
214			err := c.c(ast.Sub[0])
215			if err != nil {
216				return err
217			}
218		}
219		end := uint(len(c.insts))
220		for i := 0; i < len(splits); i++ {
221			c.setSplit(splits[i], starts[i], end)
222		}
223
224	}
225
226	return c.checkSize()
227}
228
229func (c *compiler) checkSize() error {
230	if uint(len(c.insts)*instSize) > c.sizeLimit {
231		return ErrCompiledTooBig
232	}
233	return nil
234}
235
236func (c *compiler) compileClass(ast *syntax.Regexp) error {
237	if len(ast.Rune) == 0 {
238		return nil
239	}
240	jmps := make([]uint, 0, len(ast.Rune)-2)
241	// does not do last pair
242	for i := 0; i < len(ast.Rune)-2; i += 2 {
243		rstart := ast.Rune[i]
244		rend := ast.Rune[i+1]
245
246		split := c.emptySplit()
247		j1 := c.top()
248		err := c.compileClassRange(rstart, rend)
249		if err != nil {
250			return err
251		}
252		jmps = append(jmps, c.emptyJump())
253		j2 := c.top()
254		c.setSplit(split, j1, j2)
255	}
256	// handle last pair
257	rstart := ast.Rune[len(ast.Rune)-2]
258	rend := ast.Rune[len(ast.Rune)-1]
259	err := c.compileClassRange(rstart, rend)
260	if err != nil {
261		return err
262	}
263	end := c.top()
264	for _, jmp := range jmps {
265		c.setJump(jmp, end)
266	}
267	return nil
268}
269
270func (c *compiler) compileClassRange(startR, endR rune) (err error) {
271	c.sequences, c.rangeStack, err = utf8.NewSequencesPrealloc(
272		startR, endR, c.sequences, c.rangeStack, c.startBytes, c.endBytes)
273	if err != nil {
274		return err
275	}
276	jmps := make([]uint, 0, len(c.sequences)-1)
277	// does not do last entry
278	for i := 0; i < len(c.sequences)-1; i++ {
279		seq := c.sequences[i]
280		split := c.emptySplit()
281		j1 := c.top()
282		c.compileUtf8Ranges(seq)
283		jmps = append(jmps, c.emptyJump())
284		j2 := c.top()
285		c.setSplit(split, j1, j2)
286	}
287	// handle last entry
288	c.compileUtf8Ranges(c.sequences[len(c.sequences)-1])
289	end := c.top()
290	for _, jmp := range jmps {
291		c.setJump(jmp, end)
292	}
293
294	return nil
295}
296
297func (c *compiler) compileUtf8Ranges(seq utf8.Sequence) {
298	for _, r := range seq {
299		inst := c.allocInst()
300		inst.op = OpRange
301		inst.rangeStart = r.Start
302		inst.rangeEnd = r.End
303		c.insts = append(c.insts, inst)
304	}
305}
306
307func (c *compiler) emptySplit() uint {
308	inst := c.allocInst()
309	inst.op = OpSplit
310	c.insts = append(c.insts, inst)
311	return c.top() - 1
312}
313
314func (c *compiler) emptyJump() uint {
315	inst := c.allocInst()
316	inst.op = OpJmp
317	c.insts = append(c.insts, inst)
318	return c.top() - 1
319}
320
321func (c *compiler) setSplit(i, pc1, pc2 uint) {
322	split := c.insts[i]
323	split.splitA = pc1
324	split.splitB = pc2
325}
326
327func (c *compiler) setJump(i, pc uint) {
328	jmp := c.insts[i]
329	jmp.to = pc
330}
331
332func (c *compiler) top() uint {
333	return uint(len(c.insts))
334}
335
336func (c *compiler) allocInst() *inst {
337	if len(c.instsPool) <= 0 {
338		c.instsPool = make([]inst, 16)
339	}
340	inst := &c.instsPool[0]
341	c.instsPool = c.instsPool[1:]
342	return inst
343}
344