1package xpath
2
3import (
4	"errors"
5	"fmt"
6)
7
// flag is a set of bit options that the builder records while compiling an
// expression.
type flag int

// Builder flags. noneFlag is the empty set; filterFlag marks that a filter
// expression has been compiled (processFilterNode ORs it into builder.flag).
const (
	noneFlag   flag = 0
	filterFlag flag = 1
)
14
15// builder provides building an XPath expressions.
16type builder struct {
17	depth      int
18	flag       flag
19	firstInput query
20}
21
22// axisPredicate creates a predicate to predicating for this axis node.
23func axisPredicate(root *axisNode) func(NodeNavigator) bool {
24	// get current axix node type.
25	typ := ElementNode
26	switch root.AxeType {
27	case "attribute":
28		typ = AttributeNode
29	case "self", "parent":
30		typ = allNode
31	default:
32		switch root.Prop {
33		case "comment":
34			typ = CommentNode
35		case "text":
36			typ = TextNode
37			//	case "processing-instruction":
38		//	typ = ProcessingInstructionNode
39		case "node":
40			typ = allNode
41		}
42	}
43	nametest := root.LocalName != "" || root.Prefix != ""
44	predicate := func(n NodeNavigator) bool {
45		if typ == n.NodeType() || typ == allNode || typ == TextNode {
46			if nametest {
47				if root.LocalName == n.LocalName() && root.Prefix == n.Prefix() {
48					return true
49				}
50			} else {
51				return true
52			}
53		}
54		return false
55	}
56
57	return predicate
58}
59
60// processAxisNode processes a query for the XPath axis node.
61func (b *builder) processAxisNode(root *axisNode) (query, error) {
62	var (
63		err       error
64		qyInput   query
65		qyOutput  query
66		predicate = axisPredicate(root)
67	)
68
69	if root.Input == nil {
70		qyInput = &contextQuery{}
71	} else {
72		if root.AxeType == "child" && (root.Input.Type() == nodeAxis) {
73			if input := root.Input.(*axisNode); input.AxeType == "descendant-or-self" {
74				var qyGrandInput query
75				if input.Input != nil {
76					qyGrandInput, _ = b.processNode(input.Input)
77				} else {
78					qyGrandInput = &contextQuery{}
79				}
80				// fix #20: https://github.com/antchfx/htmlquery/issues/20
81				filter := func(n NodeNavigator) bool {
82					v := predicate(n)
83					switch root.Prop {
84					case "text":
85						v = v && n.NodeType() == TextNode
86					case "comment":
87						v = v && n.NodeType() == CommentNode
88					}
89					return v
90				}
91				qyOutput = &descendantQuery{Input: qyGrandInput, Predicate: filter, Self: true}
92				return qyOutput, nil
93			}
94		}
95		qyInput, err = b.processNode(root.Input)
96		if err != nil {
97			return nil, err
98		}
99	}
100
101	switch root.AxeType {
102	case "ancestor":
103		qyOutput = &ancestorQuery{Input: qyInput, Predicate: predicate}
104	case "ancestor-or-self":
105		qyOutput = &ancestorQuery{Input: qyInput, Predicate: predicate, Self: true}
106	case "attribute":
107		qyOutput = &attributeQuery{Input: qyInput, Predicate: predicate}
108	case "child":
109		filter := func(n NodeNavigator) bool {
110			v := predicate(n)
111			switch root.Prop {
112			case "text":
113				v = v && n.NodeType() == TextNode
114			case "node":
115				v = v && (n.NodeType() == ElementNode || n.NodeType() == TextNode)
116			case "comment":
117				v = v && n.NodeType() == CommentNode
118			}
119			return v
120		}
121		qyOutput = &childQuery{Input: qyInput, Predicate: filter}
122	case "descendant":
123		qyOutput = &descendantQuery{Input: qyInput, Predicate: predicate}
124	case "descendant-or-self":
125		qyOutput = &descendantQuery{Input: qyInput, Predicate: predicate, Self: true}
126	case "following":
127		qyOutput = &followingQuery{Input: qyInput, Predicate: predicate}
128	case "following-sibling":
129		qyOutput = &followingQuery{Input: qyInput, Predicate: predicate, Sibling: true}
130	case "parent":
131		qyOutput = &parentQuery{Input: qyInput, Predicate: predicate}
132	case "preceding":
133		qyOutput = &precedingQuery{Input: qyInput, Predicate: predicate}
134	case "preceding-sibling":
135		qyOutput = &precedingQuery{Input: qyInput, Predicate: predicate, Sibling: true}
136	case "self":
137		qyOutput = &selfQuery{Input: qyInput, Predicate: predicate}
138	case "namespace":
139		// haha,what will you do someting??
140	default:
141		err = fmt.Errorf("unknown axe type: %s", root.AxeType)
142		return nil, err
143	}
144	return qyOutput, nil
145}
146
147// processFilterNode builds query for the XPath filter predicate.
148func (b *builder) processFilterNode(root *filterNode) (query, error) {
149	b.flag |= filterFlag
150
151	qyInput, err := b.processNode(root.Input)
152	if err != nil {
153		return nil, err
154	}
155	qyCond, err := b.processNode(root.Condition)
156	if err != nil {
157		return nil, err
158	}
159	qyOutput := &filterQuery{Input: qyInput, Predicate: qyCond}
160	return qyOutput, nil
161}
162
163// processFunctionNode processes query for the XPath function node.
164func (b *builder) processFunctionNode(root *functionNode) (query, error) {
165	var qyOutput query
166	switch root.FuncName {
167	case "starts-with":
168		arg1, err := b.processNode(root.Args[0])
169		if err != nil {
170			return nil, err
171		}
172		arg2, err := b.processNode(root.Args[1])
173		if err != nil {
174			return nil, err
175		}
176		qyOutput = &functionQuery{Input: b.firstInput, Func: startwithFunc(arg1, arg2)}
177	case "ends-with":
178		arg1, err := b.processNode(root.Args[0])
179		if err != nil {
180			return nil, err
181		}
182		arg2, err := b.processNode(root.Args[1])
183		if err != nil {
184			return nil, err
185		}
186		qyOutput = &functionQuery{Input: b.firstInput, Func: endwithFunc(arg1, arg2)}
187	case "contains":
188		arg1, err := b.processNode(root.Args[0])
189		if err != nil {
190			return nil, err
191		}
192		arg2, err := b.processNode(root.Args[1])
193		if err != nil {
194			return nil, err
195		}
196
197		qyOutput = &functionQuery{Input: b.firstInput, Func: containsFunc(arg1, arg2)}
198	case "substring":
199		//substring( string , start [, length] )
200		if len(root.Args) < 2 {
201			return nil, errors.New("xpath: substring function must have at least two parameter")
202		}
203		var (
204			arg1, arg2, arg3 query
205			err              error
206		)
207		if arg1, err = b.processNode(root.Args[0]); err != nil {
208			return nil, err
209		}
210		if arg2, err = b.processNode(root.Args[1]); err != nil {
211			return nil, err
212		}
213		if len(root.Args) == 3 {
214			if arg3, err = b.processNode(root.Args[2]); err != nil {
215				return nil, err
216			}
217		}
218		qyOutput = &functionQuery{Input: b.firstInput, Func: substringFunc(arg1, arg2, arg3)}
219	case "substring-before", "substring-after":
220		//substring-xxxx( haystack, needle )
221		if len(root.Args) != 2 {
222			return nil, errors.New("xpath: substring-before function must have two parameters")
223		}
224		var (
225			arg1, arg2 query
226			err        error
227		)
228		if arg1, err = b.processNode(root.Args[0]); err != nil {
229			return nil, err
230		}
231		if arg2, err = b.processNode(root.Args[1]); err != nil {
232			return nil, err
233		}
234		qyOutput = &functionQuery{
235			Input: b.firstInput,
236			Func:  substringIndFunc(arg1, arg2, root.FuncName == "substring-after"),
237		}
238	case "string-length":
239		// string-length( [string] )
240		if len(root.Args) < 1 {
241			return nil, errors.New("xpath: string-length function must have at least one parameter")
242		}
243		arg1, err := b.processNode(root.Args[0])
244		if err != nil {
245			return nil, err
246		}
247		qyOutput = &functionQuery{Input: b.firstInput, Func: stringLengthFunc(arg1)}
248	case "normalize-space":
249		if len(root.Args) == 0 {
250			return nil, errors.New("xpath: normalize-space function must have at least one parameter")
251		}
252		argQuery, err := b.processNode(root.Args[0])
253		if err != nil {
254			return nil, err
255		}
256		qyOutput = &functionQuery{Input: argQuery, Func: normalizespaceFunc}
257	case "replace":
258		//replace( string , string, string )
259		if len(root.Args) != 3 {
260			return nil, errors.New("xpath: replace function must have three parameters")
261		}
262		var (
263			arg1, arg2, arg3 query
264			err              error
265		)
266		if arg1, err = b.processNode(root.Args[0]); err != nil {
267			return nil, err
268		}
269		if arg2, err = b.processNode(root.Args[1]); err != nil {
270			return nil, err
271		}
272		if arg3, err = b.processNode(root.Args[2]); err != nil {
273			return nil, err
274		}
275		qyOutput = &functionQuery{Input: b.firstInput, Func: replaceFunc(arg1, arg2, arg3)}
276	case "translate":
277		//translate( string , string, string )
278		if len(root.Args) != 3 {
279			return nil, errors.New("xpath: translate function must have three parameters")
280		}
281		var (
282			arg1, arg2, arg3 query
283			err              error
284		)
285		if arg1, err = b.processNode(root.Args[0]); err != nil {
286			return nil, err
287		}
288		if arg2, err = b.processNode(root.Args[1]); err != nil {
289			return nil, err
290		}
291		if arg3, err = b.processNode(root.Args[2]); err != nil {
292			return nil, err
293		}
294		qyOutput = &functionQuery{Input: b.firstInput, Func: translateFunc(arg1, arg2, arg3)}
295	case "not":
296		if len(root.Args) == 0 {
297			return nil, errors.New("xpath: not function must have at least one parameter")
298		}
299		argQuery, err := b.processNode(root.Args[0])
300		if err != nil {
301			return nil, err
302		}
303		qyOutput = &functionQuery{Input: argQuery, Func: notFunc}
304	case "name", "local-name", "namespace-uri":
305		if len(root.Args) > 1 {
306			return nil, fmt.Errorf("xpath: %s function must have at most one parameter", root.FuncName)
307		}
308		var (
309			arg query
310			err error
311		)
312		if len(root.Args) == 1 {
313			arg, err = b.processNode(root.Args[0])
314			if err != nil {
315				return nil, err
316			}
317		}
318		switch root.FuncName {
319		case "name":
320			qyOutput = &functionQuery{Input: b.firstInput, Func: nameFunc(arg)}
321		case "local-name":
322			qyOutput = &functionQuery{Input: b.firstInput, Func: localNameFunc(arg)}
323		case "namespace-uri":
324			qyOutput = &functionQuery{Input: b.firstInput, Func: namespaceFunc(arg)}
325		}
326	case "true", "false":
327		val := root.FuncName == "true"
328		qyOutput = &functionQuery{
329			Input: b.firstInput,
330			Func: func(_ query, _ iterator) interface{} {
331				return val
332			},
333		}
334	case "last":
335		qyOutput = &functionQuery{Input: b.firstInput, Func: lastFunc}
336	case "position":
337		qyOutput = &functionQuery{Input: b.firstInput, Func: positionFunc}
338	case "boolean", "number", "string":
339		inp := b.firstInput
340		if len(root.Args) > 1 {
341			return nil, fmt.Errorf("xpath: %s function must have at most one parameter", root.FuncName)
342		}
343		if len(root.Args) == 1 {
344			argQuery, err := b.processNode(root.Args[0])
345			if err != nil {
346				return nil, err
347			}
348			inp = argQuery
349		}
350		f := &functionQuery{Input: inp}
351		switch root.FuncName {
352		case "boolean":
353			f.Func = booleanFunc
354		case "string":
355			f.Func = stringFunc
356		case "number":
357			f.Func = numberFunc
358		}
359		qyOutput = f
360	case "count":
361		//if b.firstInput == nil {
362		//	return nil, errors.New("xpath: expression must evaluate to node-set")
363		//}
364		if len(root.Args) == 0 {
365			return nil, fmt.Errorf("xpath: count(node-sets) function must with have parameters node-sets")
366		}
367		argQuery, err := b.processNode(root.Args[0])
368		if err != nil {
369			return nil, err
370		}
371		qyOutput = &functionQuery{Input: argQuery, Func: countFunc}
372	case "sum":
373		if len(root.Args) == 0 {
374			return nil, fmt.Errorf("xpath: sum(node-sets) function must with have parameters node-sets")
375		}
376		argQuery, err := b.processNode(root.Args[0])
377		if err != nil {
378			return nil, err
379		}
380		qyOutput = &functionQuery{Input: argQuery, Func: sumFunc}
381	case "ceiling", "floor", "round":
382		if len(root.Args) == 0 {
383			return nil, fmt.Errorf("xpath: ceiling(node-sets) function must with have parameters node-sets")
384		}
385		argQuery, err := b.processNode(root.Args[0])
386		if err != nil {
387			return nil, err
388		}
389		f := &functionQuery{Input: argQuery}
390		switch root.FuncName {
391		case "ceiling":
392			f.Func = ceilingFunc
393		case "floor":
394			f.Func = floorFunc
395		case "round":
396			f.Func = roundFunc
397		}
398		qyOutput = f
399	case "concat":
400		if len(root.Args) < 2 {
401			return nil, fmt.Errorf("xpath: concat() must have at least two arguments")
402		}
403		var args []query
404		for _, v := range root.Args {
405			q, err := b.processNode(v)
406			if err != nil {
407				return nil, err
408			}
409			args = append(args, q)
410		}
411		qyOutput = &functionQuery{Input: b.firstInput, Func: concatFunc(args...)}
412	case "reverse":
413		if len(root.Args) == 0 {
414			return nil, fmt.Errorf("xpath: reverse(node-sets) function must with have parameters node-sets")
415		}
416		argQuery, err := b.processNode(root.Args[0])
417		if err != nil {
418			return nil, err
419		}
420		qyOutput = &transformFunctionQuery{Input: argQuery, Func: reverseFunc}
421	default:
422		return nil, fmt.Errorf("not yet support this function %s()", root.FuncName)
423	}
424	return qyOutput, nil
425}
426
427func (b *builder) processOperatorNode(root *operatorNode) (query, error) {
428	left, err := b.processNode(root.Left)
429	if err != nil {
430		return nil, err
431	}
432	right, err := b.processNode(root.Right)
433	if err != nil {
434		return nil, err
435	}
436	var qyOutput query
437	switch root.Op {
438	case "+", "-", "div", "mod": // Numeric operator
439		var exprFunc func(interface{}, interface{}) interface{}
440		switch root.Op {
441		case "+":
442			exprFunc = plusFunc
443		case "-":
444			exprFunc = minusFunc
445		case "div":
446			exprFunc = divFunc
447		case "mod":
448			exprFunc = modFunc
449		}
450		qyOutput = &numericQuery{Left: left, Right: right, Do: exprFunc}
451	case "=", ">", ">=", "<", "<=", "!=":
452		var exprFunc func(iterator, interface{}, interface{}) interface{}
453		switch root.Op {
454		case "=":
455			exprFunc = eqFunc
456		case ">":
457			exprFunc = gtFunc
458		case ">=":
459			exprFunc = geFunc
460		case "<":
461			exprFunc = ltFunc
462		case "<=":
463			exprFunc = leFunc
464		case "!=":
465			exprFunc = neFunc
466		}
467		qyOutput = &logicalQuery{Left: left, Right: right, Do: exprFunc}
468	case "or", "and":
469		isOr := false
470		if root.Op == "or" {
471			isOr = true
472		}
473		qyOutput = &booleanQuery{Left: left, Right: right, IsOr: isOr}
474	case "|":
475		qyOutput = &unionQuery{Left: left, Right: right}
476	}
477	return qyOutput, nil
478}
479
480func (b *builder) processNode(root node) (q query, err error) {
481	if b.depth = b.depth + 1; b.depth > 1024 {
482		err = errors.New("the xpath expressions is too complex")
483		return
484	}
485
486	switch root.Type() {
487	case nodeConstantOperand:
488		n := root.(*operandNode)
489		q = &constantQuery{Val: n.Val}
490	case nodeRoot:
491		q = &contextQuery{Root: true}
492	case nodeAxis:
493		q, err = b.processAxisNode(root.(*axisNode))
494		b.firstInput = q
495	case nodeFilter:
496		q, err = b.processFilterNode(root.(*filterNode))
497	case nodeFunction:
498		q, err = b.processFunctionNode(root.(*functionNode))
499	case nodeOperator:
500		q, err = b.processOperatorNode(root.(*operatorNode))
501	}
502	return
503}
504
505// build builds a specified XPath expressions expr.
506func build(expr string) (q query, err error) {
507	defer func() {
508		if e := recover(); e != nil {
509			switch x := e.(type) {
510			case string:
511				err = errors.New(x)
512			case error:
513				err = x
514			default:
515				err = errors.New("unknown panic")
516			}
517		}
518	}()
519	root := parse(expr)
520	b := &builder{}
521	return b.processNode(root)
522}
523