1package xpath
2
3import (
4	"errors"
5	"fmt"
6)
7
// flag holds build options that the builder records while it walks the
// parsed expression. Values are combined with bitwise OR.
type flag int

const (
	// noneFlag is the zero value: no option is set.
	noneFlag flag = iota
	// filterFlag is OR-ed into builder.flag by processFilterNode.
	filterFlag
)
14
// builder converts a parsed XPath expression tree into a query.
//
// Fields:
//   - depth counts calls to processNode; processNode fails once the count
//     exceeds 1024. It is never decremented.
//   - flag accumulates build options; processFilterNode sets filterFlag on it.
//   - firstInput is the query returned for the most recently processed axis
//     node; processFunctionNode uses it as the input of many function queries.
type builder struct {
	depth      int
	flag       flag
	firstInput query
}
21
22// axisPredicate creates a predicate to predicating for this axis node.
23func axisPredicate(root *axisNode) func(NodeNavigator) bool {
24	// get current axix node type.
25	typ := ElementNode
26	switch root.AxeType {
27	case "attribute":
28		typ = AttributeNode
29	case "self", "parent":
30		typ = allNode
31	default:
32		switch root.Prop {
33		case "comment":
34			typ = CommentNode
35		case "text":
36			typ = TextNode
37			//	case "processing-instruction":
38		//	typ = ProcessingInstructionNode
39		case "node":
40			typ = allNode
41		}
42	}
43	nametest := root.LocalName != "" || root.Prefix != ""
44	predicate := func(n NodeNavigator) bool {
45		if typ == n.NodeType() || typ == allNode || typ == TextNode {
46			if nametest {
47				if root.LocalName == n.LocalName() && root.Prefix == n.Prefix() {
48					return true
49				}
50			} else {
51				return true
52			}
53		}
54		return false
55	}
56
57	return predicate
58}
59
60// processAxisNode processes a query for the XPath axis node.
61func (b *builder) processAxisNode(root *axisNode) (query, error) {
62	var (
63		err       error
64		qyInput   query
65		qyOutput  query
66		predicate = axisPredicate(root)
67	)
68
69	if root.Input == nil {
70		qyInput = &contextQuery{}
71	} else {
72		if root.AxeType == "child" && (root.Input.Type() == nodeAxis) {
73			if input := root.Input.(*axisNode); input.AxeType == "descendant-or-self" {
74				var qyGrandInput query
75				if input.Input != nil {
76					qyGrandInput, _ = b.processNode(input.Input)
77				} else {
78					qyGrandInput = &contextQuery{}
79				}
80				qyOutput = &descendantQuery{Input: qyGrandInput, Predicate: predicate, Self: true}
81				return qyOutput, nil
82			}
83		}
84		qyInput, err = b.processNode(root.Input)
85		if err != nil {
86			return nil, err
87		}
88	}
89
90	switch root.AxeType {
91	case "ancestor":
92		qyOutput = &ancestorQuery{Input: qyInput, Predicate: predicate}
93	case "ancestor-or-self":
94		qyOutput = &ancestorQuery{Input: qyInput, Predicate: predicate, Self: true}
95	case "attribute":
96		qyOutput = &attributeQuery{Input: qyInput, Predicate: predicate}
97	case "child":
98		filter := func(n NodeNavigator) bool {
99			v := predicate(n)
100			switch root.Prop {
101			case "text":
102				v = v && n.NodeType() == TextNode
103			case "node":
104				v = v && (n.NodeType() == ElementNode || n.NodeType() == TextNode)
105			case "comment":
106				v = v && n.NodeType() == CommentNode
107			}
108			return v
109		}
110		qyOutput = &childQuery{Input: qyInput, Predicate: filter}
111	case "descendant":
112		qyOutput = &descendantQuery{Input: qyInput, Predicate: predicate}
113	case "descendant-or-self":
114		qyOutput = &descendantQuery{Input: qyInput, Predicate: predicate, Self: true}
115	case "following":
116		qyOutput = &followingQuery{Input: qyInput, Predicate: predicate}
117	case "following-sibling":
118		qyOutput = &followingQuery{Input: qyInput, Predicate: predicate, Sibling: true}
119	case "parent":
120		qyOutput = &parentQuery{Input: qyInput, Predicate: predicate}
121	case "preceding":
122		qyOutput = &precedingQuery{Input: qyInput, Predicate: predicate}
123	case "preceding-sibling":
124		qyOutput = &precedingQuery{Input: qyInput, Predicate: predicate, Sibling: true}
125	case "self":
126		qyOutput = &selfQuery{Input: qyInput, Predicate: predicate}
127	case "namespace":
128		// haha,what will you do someting??
129	default:
130		err = fmt.Errorf("unknown axe type: %s", root.AxeType)
131		return nil, err
132	}
133	return qyOutput, nil
134}
135
136// processFilterNode builds query for the XPath filter predicate.
137func (b *builder) processFilterNode(root *filterNode) (query, error) {
138	b.flag |= filterFlag
139
140	qyInput, err := b.processNode(root.Input)
141	if err != nil {
142		return nil, err
143	}
144	qyCond, err := b.processNode(root.Condition)
145	if err != nil {
146		return nil, err
147	}
148	qyOutput := &filterQuery{Input: qyInput, Predicate: qyCond}
149	return qyOutput, nil
150}
151
152// processFunctionNode processes query for the XPath function node.
153func (b *builder) processFunctionNode(root *functionNode) (query, error) {
154	var qyOutput query
155	switch root.FuncName {
156	case "starts-with":
157		arg1, err := b.processNode(root.Args[0])
158		if err != nil {
159			return nil, err
160		}
161		arg2, err := b.processNode(root.Args[1])
162		if err != nil {
163			return nil, err
164		}
165		qyOutput = &functionQuery{Input: b.firstInput, Func: startwithFunc(arg1, arg2)}
166	case "ends-with":
167		arg1, err := b.processNode(root.Args[0])
168		if err != nil {
169			return nil, err
170		}
171		arg2, err := b.processNode(root.Args[1])
172		if err != nil {
173			return nil, err
174		}
175		qyOutput = &functionQuery{Input: b.firstInput, Func: endwithFunc(arg1, arg2)}
176	case "contains":
177		arg1, err := b.processNode(root.Args[0])
178		if err != nil {
179			return nil, err
180		}
181		arg2, err := b.processNode(root.Args[1])
182		if err != nil {
183			return nil, err
184		}
185
186		qyOutput = &functionQuery{Input: b.firstInput, Func: containsFunc(arg1, arg2)}
187	case "substring":
188		//substring( string , start [, length] )
189		if len(root.Args) < 2 {
190			return nil, errors.New("xpath: substring function must have at least two parameter")
191		}
192		var (
193			arg1, arg2, arg3 query
194			err              error
195		)
196		if arg1, err = b.processNode(root.Args[0]); err != nil {
197			return nil, err
198		}
199		if arg2, err = b.processNode(root.Args[1]); err != nil {
200			return nil, err
201		}
202		if len(root.Args) == 3 {
203			if arg3, err = b.processNode(root.Args[2]); err != nil {
204				return nil, err
205			}
206		}
207		qyOutput = &functionQuery{Input: b.firstInput, Func: substringFunc(arg1, arg2, arg3)}
208	case "substring-before", "substring-after":
209		//substring-xxxx( haystack, needle )
210		if len(root.Args) != 2 {
211			return nil, errors.New("xpath: substring-before function must have two parameters")
212		}
213		var (
214			arg1, arg2 query
215			err        error
216		)
217		if arg1, err = b.processNode(root.Args[0]); err != nil {
218			return nil, err
219		}
220		if arg2, err = b.processNode(root.Args[1]); err != nil {
221			return nil, err
222		}
223		qyOutput = &functionQuery{
224			Input: b.firstInput,
225			Func:  substringIndFunc(arg1, arg2, root.FuncName == "substring-after"),
226		}
227	case "string-length":
228		// string-length( [string] )
229		if len(root.Args) < 1 {
230			return nil, errors.New("xpath: string-length function must have at least one parameter")
231		}
232		arg1, err := b.processNode(root.Args[0])
233		if err != nil {
234			return nil, err
235		}
236		qyOutput = &functionQuery{Input: b.firstInput, Func: stringLengthFunc(arg1)}
237	case "normalize-space":
238		if len(root.Args) == 0 {
239			return nil, errors.New("xpath: normalize-space function must have at least one parameter")
240		}
241		argQuery, err := b.processNode(root.Args[0])
242		if err != nil {
243			return nil, err
244		}
245		qyOutput = &functionQuery{Input: argQuery, Func: normalizespaceFunc}
246	case "translate":
247		//translate( string , string, string )
248		if len(root.Args) != 3 {
249			return nil, errors.New("xpath: translate function must have three parameters")
250		}
251		var (
252			arg1, arg2, arg3 query
253			err              error
254		)
255		if arg1, err = b.processNode(root.Args[0]); err != nil {
256			return nil, err
257		}
258		if arg2, err = b.processNode(root.Args[1]); err != nil {
259			return nil, err
260		}
261		if arg3, err = b.processNode(root.Args[2]); err != nil {
262			return nil, err
263		}
264		qyOutput = &functionQuery{Input: b.firstInput, Func: translateFunc(arg1, arg2, arg3)}
265	case "not":
266		if len(root.Args) == 0 {
267			return nil, errors.New("xpath: not function must have at least one parameter")
268		}
269		argQuery, err := b.processNode(root.Args[0])
270		if err != nil {
271			return nil, err
272		}
273		qyOutput = &functionQuery{Input: argQuery, Func: notFunc}
274	case "name", "local-name", "namespace-uri":
275		inp := b.firstInput
276		if len(root.Args) > 1 {
277			return nil, fmt.Errorf("xpath: %s function must have at most one parameter", root.FuncName)
278		}
279		if len(root.Args) == 1 {
280			argQuery, err := b.processNode(root.Args[0])
281			if err != nil {
282				return nil, err
283			}
284			inp = argQuery
285		}
286		f := &functionQuery{Input: inp}
287		switch root.FuncName {
288		case "name":
289			f.Func = nameFunc
290		case "local-name":
291			f.Func = localNameFunc
292		case "namespace-uri":
293			f.Func = namespaceFunc
294		}
295		qyOutput = f
296	case "true", "false":
297		val := root.FuncName == "true"
298		qyOutput = &functionQuery{
299			Input: b.firstInput,
300			Func: func(_ query, _ iterator) interface{} {
301				return val
302			},
303		}
304	case "last":
305		qyOutput = &functionQuery{Input: b.firstInput, Func: lastFunc}
306	case "position":
307		qyOutput = &functionQuery{Input: b.firstInput, Func: positionFunc}
308	case "boolean", "number", "string":
309		inp := b.firstInput
310		if len(root.Args) > 1 {
311			return nil, fmt.Errorf("xpath: %s function must have at most one parameter", root.FuncName)
312		}
313		if len(root.Args) == 1 {
314			argQuery, err := b.processNode(root.Args[0])
315			if err != nil {
316				return nil, err
317			}
318			inp = argQuery
319		}
320		f := &functionQuery{Input: inp}
321		switch root.FuncName {
322		case "boolean":
323			f.Func = booleanFunc
324		case "string":
325			f.Func = stringFunc
326		case "number":
327			f.Func = numberFunc
328		}
329		qyOutput = f
330	case "count":
331		//if b.firstInput == nil {
332		//	return nil, errors.New("xpath: expression must evaluate to node-set")
333		//}
334		if len(root.Args) == 0 {
335			return nil, fmt.Errorf("xpath: count(node-sets) function must with have parameters node-sets")
336		}
337		argQuery, err := b.processNode(root.Args[0])
338		if err != nil {
339			return nil, err
340		}
341		qyOutput = &functionQuery{Input: argQuery, Func: countFunc}
342	case "sum":
343		if len(root.Args) == 0 {
344			return nil, fmt.Errorf("xpath: sum(node-sets) function must with have parameters node-sets")
345		}
346		argQuery, err := b.processNode(root.Args[0])
347		if err != nil {
348			return nil, err
349		}
350		qyOutput = &functionQuery{Input: argQuery, Func: sumFunc}
351	case "ceiling", "floor", "round":
352		if len(root.Args) == 0 {
353			return nil, fmt.Errorf("xpath: ceiling(node-sets) function must with have parameters node-sets")
354		}
355		argQuery, err := b.processNode(root.Args[0])
356		if err != nil {
357			return nil, err
358		}
359		f := &functionQuery{Input: argQuery}
360		switch root.FuncName {
361		case "ceiling":
362			f.Func = ceilingFunc
363		case "floor":
364			f.Func = floorFunc
365		case "round":
366			f.Func = roundFunc
367		}
368		qyOutput = f
369	case "concat":
370		if len(root.Args) < 2 {
371			return nil, fmt.Errorf("xpath: concat() must have at least two arguments")
372		}
373		var args []query
374		for _, v := range root.Args {
375			q, err := b.processNode(v)
376			if err != nil {
377				return nil, err
378			}
379			args = append(args, q)
380		}
381		qyOutput = &functionQuery{Input: b.firstInput, Func: concatFunc(args...)}
382	default:
383		return nil, fmt.Errorf("not yet support this function %s()", root.FuncName)
384	}
385	return qyOutput, nil
386}
387
388func (b *builder) processOperatorNode(root *operatorNode) (query, error) {
389	left, err := b.processNode(root.Left)
390	if err != nil {
391		return nil, err
392	}
393	right, err := b.processNode(root.Right)
394	if err != nil {
395		return nil, err
396	}
397	var qyOutput query
398	switch root.Op {
399	case "+", "-", "div", "mod": // Numeric operator
400		var exprFunc func(interface{}, interface{}) interface{}
401		switch root.Op {
402		case "+":
403			exprFunc = plusFunc
404		case "-":
405			exprFunc = minusFunc
406		case "div":
407			exprFunc = divFunc
408		case "mod":
409			exprFunc = modFunc
410		}
411		qyOutput = &numericQuery{Left: left, Right: right, Do: exprFunc}
412	case "=", ">", ">=", "<", "<=", "!=":
413		var exprFunc func(iterator, interface{}, interface{}) interface{}
414		switch root.Op {
415		case "=":
416			exprFunc = eqFunc
417		case ">":
418			exprFunc = gtFunc
419		case ">=":
420			exprFunc = geFunc
421		case "<":
422			exprFunc = ltFunc
423		case "<=":
424			exprFunc = leFunc
425		case "!=":
426			exprFunc = neFunc
427		}
428		qyOutput = &logicalQuery{Left: left, Right: right, Do: exprFunc}
429	case "or", "and":
430		isOr := false
431		if root.Op == "or" {
432			isOr = true
433		}
434		qyOutput = &booleanQuery{Left: left, Right: right, IsOr: isOr}
435	case "|":
436		qyOutput = &unionQuery{Left: left, Right: right}
437	}
438	return qyOutput, nil
439}
440
441func (b *builder) processNode(root node) (q query, err error) {
442	if b.depth = b.depth + 1; b.depth > 1024 {
443		err = errors.New("the xpath expressions is too complex")
444		return
445	}
446
447	switch root.Type() {
448	case nodeConstantOperand:
449		n := root.(*operandNode)
450		q = &constantQuery{Val: n.Val}
451	case nodeRoot:
452		q = &contextQuery{Root: true}
453	case nodeAxis:
454		q, err = b.processAxisNode(root.(*axisNode))
455		b.firstInput = q
456	case nodeFilter:
457		q, err = b.processFilterNode(root.(*filterNode))
458	case nodeFunction:
459		q, err = b.processFunctionNode(root.(*functionNode))
460	case nodeOperator:
461		q, err = b.processOperatorNode(root.(*operatorNode))
462	}
463	return
464}
465
466// build builds a specified XPath expressions expr.
467func build(expr string) (q query, err error) {
468	defer func() {
469		if e := recover(); e != nil {
470			switch x := e.(type) {
471			case string:
472				err = errors.New(x)
473			case error:
474				err = x
475			default:
476				err = errors.New("unknown panic")
477			}
478		}
479	}()
480	root := parse(expr)
481	b := &builder{}
482	return b.processNode(root)
483}
484