1package xpath
2
3import (
4	"errors"
5	"fmt"
6	"math"
7	"strconv"
8	"strings"
9	"sync"
10	"unicode"
11)
12
// stringBuilder is the minimal string-accumulation API used by this package.
// It is satisfied both by strings.Builder (Go 1.10+) and by bytes.Buffer
// (earlier Go releases), so either can back the pooled builders.
type stringBuilder interface {
	Grow(n int)
	WriteString(s string) (int, error)
	WriteRune(r rune) (n int, err error)
	String() string
	Reset()
}
22
23var builderPool = sync.Pool{New: func() interface{} {
24	return newStringBuilder()
25}}
26
27// The XPath function list.
28
29func predicate(q query) func(NodeNavigator) bool {
30	type Predicater interface {
31		Test(NodeNavigator) bool
32	}
33	if p, ok := q.(Predicater); ok {
34		return p.Test
35	}
36	return func(NodeNavigator) bool { return true }
37}
38
39// positionFunc is a XPath Node Set functions position().
40func positionFunc(q query, t iterator) interface{} {
41	var (
42		count = 1
43		node  = t.Current().Copy()
44	)
45	test := predicate(q)
46	for node.MoveToPrevious() {
47		if test(node) {
48			count++
49		}
50	}
51	return float64(count)
52}
53
54// lastFunc is a XPath Node Set functions last().
55func lastFunc(q query, t iterator) interface{} {
56	var (
57		count = 0
58		node  = t.Current().Copy()
59	)
60	node.MoveToFirst()
61	test := predicate(q)
62	for {
63		if test(node) {
64			count++
65		}
66		if !node.MoveToNext() {
67			break
68		}
69	}
70	return float64(count)
71}
72
73// countFunc is a XPath Node Set functions count(node-set).
74func countFunc(q query, t iterator) interface{} {
75	var count = 0
76	q = functionArgs(q)
77	test := predicate(q)
78	switch typ := q.Evaluate(t).(type) {
79	case query:
80		for node := typ.Select(t); node != nil; node = typ.Select(t) {
81			if test(node) {
82				count++
83			}
84		}
85	}
86	return float64(count)
87}
88
89// sumFunc is a XPath Node Set functions sum(node-set).
90func sumFunc(q query, t iterator) interface{} {
91	var sum float64
92	switch typ := functionArgs(q).Evaluate(t).(type) {
93	case query:
94		for node := typ.Select(t); node != nil; node = typ.Select(t) {
95			if v, err := strconv.ParseFloat(node.Value(), 64); err == nil {
96				sum += v
97			}
98		}
99	case float64:
100		sum = typ
101	case string:
102		v, err := strconv.ParseFloat(typ, 64)
103		if err != nil {
104			panic(errors.New("sum() function argument type must be a node-set or number"))
105		}
106		sum = v
107	}
108	return sum
109}
110
111func asNumber(t iterator, o interface{}) float64 {
112	switch typ := o.(type) {
113	case query:
114		node := typ.Select(t)
115		if node == nil {
116			return float64(0)
117		}
118		if v, err := strconv.ParseFloat(node.Value(), 64); err == nil {
119			return v
120		}
121	case float64:
122		return typ
123	case string:
124		v, err := strconv.ParseFloat(typ, 64)
125		if err != nil {
126			panic(errors.New("ceiling() function argument type must be a node-set or number"))
127		}
128		return v
129	}
130	return 0
131}
132
133// ceilingFunc is a XPath Node Set functions ceiling(node-set).
134func ceilingFunc(q query, t iterator) interface{} {
135	val := asNumber(t, functionArgs(q).Evaluate(t))
136	return math.Ceil(val)
137}
138
139// floorFunc is a XPath Node Set functions floor(node-set).
140func floorFunc(q query, t iterator) interface{} {
141	val := asNumber(t, functionArgs(q).Evaluate(t))
142	return math.Floor(val)
143}
144
145// roundFunc is a XPath Node Set functions round(node-set).
146func roundFunc(q query, t iterator) interface{} {
147	val := asNumber(t, functionArgs(q).Evaluate(t))
148	//return math.Round(val)
149	return round(val)
150}
151
152// nameFunc is a XPath functions name([node-set]).
153func nameFunc(arg query) func(query, iterator) interface{} {
154	return func(q query, t iterator) interface{} {
155		var v NodeNavigator
156		if arg == nil {
157			v = t.Current()
158		} else {
159			v = arg.Clone().Select(t)
160			if v == nil {
161				return ""
162			}
163		}
164		ns := v.Prefix()
165		if ns == "" {
166			return v.LocalName()
167		}
168		return ns + ":" + v.LocalName()
169	}
170}
171
172// localNameFunc is a XPath functions local-name([node-set]).
173func localNameFunc(arg query) func(query, iterator) interface{} {
174	return func(q query, t iterator) interface{} {
175		var v NodeNavigator
176		if arg == nil {
177			v = t.Current()
178		} else {
179			v = arg.Clone().Select(t)
180			if v == nil {
181				return ""
182			}
183		}
184		return v.LocalName()
185	}
186}
187
188// namespaceFunc is a XPath functions namespace-uri([node-set]).
189func namespaceFunc(arg query) func(query, iterator) interface{} {
190	return func(q query, t iterator) interface{} {
191		var v NodeNavigator
192		if arg == nil {
193			v = t.Current()
194		} else {
195			// Get the first node in the node-set if specified.
196			v = arg.Clone().Select(t)
197			if v == nil {
198				return ""
199			}
200		}
201		// fix about namespace-uri() bug: https://github.com/antchfx/xmlquery/issues/22
202		// TODO: In the next version, add NamespaceURL() to the NodeNavigator interface.
203		type namespaceURL interface {
204			NamespaceURL() string
205		}
206		if f, ok := v.(namespaceURL); ok {
207			return f.NamespaceURL()
208		}
209		return v.Prefix()
210	}
211}
212
213func asBool(t iterator, v interface{}) bool {
214	switch v := v.(type) {
215	case nil:
216		return false
217	case *NodeIterator:
218		return v.MoveNext()
219	case bool:
220		return v
221	case float64:
222		return v != 0
223	case string:
224		return v != ""
225	case query:
226		return v.Select(t) != nil
227	default:
228		panic(fmt.Errorf("unexpected type: %T", v))
229	}
230}
231
232func asString(t iterator, v interface{}) string {
233	switch v := v.(type) {
234	case nil:
235		return ""
236	case bool:
237		if v {
238			return "true"
239		}
240		return "false"
241	case float64:
242		return strconv.FormatFloat(v, 'g', -1, 64)
243	case string:
244		return v
245	case query:
246		node := v.Select(t)
247		if node == nil {
248			return ""
249		}
250		return node.Value()
251	default:
252		panic(fmt.Errorf("unexpected type: %T", v))
253	}
254}
255
256// booleanFunc is a XPath functions boolean([node-set]).
257func booleanFunc(q query, t iterator) interface{} {
258	v := functionArgs(q).Evaluate(t)
259	return asBool(t, v)
260}
261
262// numberFunc is a XPath functions number([node-set]).
263func numberFunc(q query, t iterator) interface{} {
264	v := functionArgs(q).Evaluate(t)
265	return asNumber(t, v)
266}
267
268// stringFunc is a XPath functions string([node-set]).
269func stringFunc(q query, t iterator) interface{} {
270	v := functionArgs(q).Evaluate(t)
271	return asString(t, v)
272}
273
274// startwithFunc is a XPath functions starts-with(string, string).
275func startwithFunc(arg1, arg2 query) func(query, iterator) interface{} {
276	return func(q query, t iterator) interface{} {
277		var (
278			m, n string
279			ok   bool
280		)
281		switch typ := functionArgs(arg1).Evaluate(t).(type) {
282		case string:
283			m = typ
284		case query:
285			node := typ.Select(t)
286			if node == nil {
287				return false
288			}
289			m = node.Value()
290		default:
291			panic(errors.New("starts-with() function argument type must be string"))
292		}
293		n, ok = functionArgs(arg2).Evaluate(t).(string)
294		if !ok {
295			panic(errors.New("starts-with() function argument type must be string"))
296		}
297		return strings.HasPrefix(m, n)
298	}
299}
300
301// endwithFunc is a XPath functions ends-with(string, string).
302func endwithFunc(arg1, arg2 query) func(query, iterator) interface{} {
303	return func(q query, t iterator) interface{} {
304		var (
305			m, n string
306			ok   bool
307		)
308		switch typ := functionArgs(arg1).Evaluate(t).(type) {
309		case string:
310			m = typ
311		case query:
312			node := typ.Select(t)
313			if node == nil {
314				return false
315			}
316			m = node.Value()
317		default:
318			panic(errors.New("ends-with() function argument type must be string"))
319		}
320		n, ok = functionArgs(arg2).Evaluate(t).(string)
321		if !ok {
322			panic(errors.New("ends-with() function argument type must be string"))
323		}
324		return strings.HasSuffix(m, n)
325	}
326}
327
328// containsFunc is a XPath functions contains(string or @attr, string).
329func containsFunc(arg1, arg2 query) func(query, iterator) interface{} {
330	return func(q query, t iterator) interface{} {
331		var (
332			m, n string
333			ok   bool
334		)
335		switch typ := functionArgs(arg1).Evaluate(t).(type) {
336		case string:
337			m = typ
338		case query:
339			node := typ.Select(t)
340			if node == nil {
341				return false
342			}
343			m = node.Value()
344		default:
345			panic(errors.New("contains() function argument type must be string"))
346		}
347
348		n, ok = functionArgs(arg2).Evaluate(t).(string)
349		if !ok {
350			panic(errors.New("contains() function argument type must be string"))
351		}
352
353		return strings.Contains(m, n)
354	}
355}
356
357// matchesFunc is an XPath function that tests a given string against a regexp pattern.
358// Note: does not support https://www.w3.org/TR/xpath-functions-31/#func-matches 3rd optional `flags` argument; if
359// needed, directly put flags in the regexp pattern, such as `(?i)^pattern$` for `i` flag.
360func matchesFunc(arg1, arg2 query) func(query, iterator) interface{} {
361	return func(q query, t iterator) interface{} {
362		var s string
363		switch typ := functionArgs(arg1).Evaluate(t).(type) {
364		case string:
365			s = typ
366		case query:
367			node := typ.Select(t)
368			if node == nil {
369				return ""
370			}
371			s = node.Value()
372		}
373		var pattern string
374		var ok bool
375		if pattern, ok = functionArgs(arg2).Evaluate(t).(string); !ok {
376			panic(errors.New("matches() function second argument type must be string"))
377		}
378		re, err := getRegexp(pattern)
379		if err != nil {
380			panic(fmt.Errorf("matches() function second argument is not a valid regexp pattern, err: %s", err.Error()))
381		}
382		return re.MatchString(s)
383	}
384}
385
386// normalizespaceFunc is XPath functions normalize-space(string?)
387func normalizespaceFunc(q query, t iterator) interface{} {
388	var m string
389	switch typ := functionArgs(q).Evaluate(t).(type) {
390	case string:
391		m = typ
392	case query:
393		node := typ.Select(t)
394		if node == nil {
395			return ""
396		}
397		m = node.Value()
398	}
399	var b = builderPool.Get().(stringBuilder)
400	b.Grow(len(m))
401
402	runeStr := []rune(strings.TrimSpace(m))
403	l := len(runeStr)
404	for i := range runeStr {
405		r := runeStr[i]
406		isSpace := unicode.IsSpace(r)
407		if !(isSpace && (i+1 < l && unicode.IsSpace(runeStr[i+1]))) {
408			if isSpace {
409				r = ' '
410			}
411			b.WriteRune(r)
412		}
413	}
414	result := b.String()
415	b.Reset()
416	builderPool.Put(b)
417
418	return result
419}
420
421// substringFunc is XPath functions substring function returns a part of a given string.
422func substringFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} {
423	return func(q query, t iterator) interface{} {
424		var m string
425		switch typ := functionArgs(arg1).Evaluate(t).(type) {
426		case string:
427			m = typ
428		case query:
429			node := typ.Select(t)
430			if node == nil {
431				return ""
432			}
433			m = node.Value()
434		}
435
436		var start, length float64
437		var ok bool
438
439		if start, ok = functionArgs(arg2).Evaluate(t).(float64); !ok {
440			panic(errors.New("substring() function first argument type must be int"))
441		} else if start < 1 {
442			panic(errors.New("substring() function first argument type must be >= 1"))
443		}
444		start--
445		if arg3 != nil {
446			if length, ok = functionArgs(arg3).Evaluate(t).(float64); !ok {
447				panic(errors.New("substring() function second argument type must be int"))
448			}
449		}
450		if (len(m) - int(start)) < int(length) {
451			panic(errors.New("substring() function start and length argument out of range"))
452		}
453		if length > 0 {
454			return m[int(start):int(length+start)]
455		}
456		return m[int(start):]
457	}
458}
459
460// substringIndFunc is XPath functions substring-before/substring-after function returns a part of a given string.
461func substringIndFunc(arg1, arg2 query, after bool) func(query, iterator) interface{} {
462	return func(q query, t iterator) interface{} {
463		var str string
464		switch v := functionArgs(arg1).Evaluate(t).(type) {
465		case string:
466			str = v
467		case query:
468			node := v.Select(t)
469			if node == nil {
470				return ""
471			}
472			str = node.Value()
473		}
474		var word string
475		switch v := functionArgs(arg2).Evaluate(t).(type) {
476		case string:
477			word = v
478		case query:
479			node := v.Select(t)
480			if node == nil {
481				return ""
482			}
483			word = node.Value()
484		}
485		if word == "" {
486			return ""
487		}
488
489		i := strings.Index(str, word)
490		if i < 0 {
491			return ""
492		}
493		if after {
494			return str[i+len(word):]
495		}
496		return str[:i]
497	}
498}
499
500// stringLengthFunc is XPATH string-length( [string] ) function that returns a number
501// equal to the number of characters in a given string.
502func stringLengthFunc(arg1 query) func(query, iterator) interface{} {
503	return func(q query, t iterator) interface{} {
504		switch v := functionArgs(arg1).Evaluate(t).(type) {
505		case string:
506			return float64(len(v))
507		case query:
508			node := v.Select(t)
509			if node == nil {
510				break
511			}
512			return float64(len(node.Value()))
513		}
514		return float64(0)
515	}
516}
517
518// translateFunc is XPath functions translate() function returns a replaced string.
519func translateFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} {
520	return func(q query, t iterator) interface{} {
521		str := asString(t, functionArgs(arg1).Evaluate(t))
522		src := asString(t, functionArgs(arg2).Evaluate(t))
523		dst := asString(t, functionArgs(arg3).Evaluate(t))
524
525		replace := make([]string, 0, len(src))
526		for i, s := range src {
527			d := ""
528			if i < len(dst) {
529				d = string(dst[i])
530			}
531			replace = append(replace, string(s), d)
532		}
533		return strings.NewReplacer(replace...).Replace(str)
534	}
535}
536
537// replaceFunc is XPath functions replace() function returns a replaced string.
538func replaceFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} {
539	return func(q query, t iterator) interface{} {
540		str := asString(t, functionArgs(arg1).Evaluate(t))
541		src := asString(t, functionArgs(arg2).Evaluate(t))
542		dst := asString(t, functionArgs(arg3).Evaluate(t))
543
544		return strings.Replace(str, src, dst, -1)
545	}
546}
547
548// notFunc is XPATH functions not(expression) function operation.
549func notFunc(q query, t iterator) interface{} {
550	switch v := functionArgs(q).Evaluate(t).(type) {
551	case bool:
552		return !v
553	case query:
554		node := v.Select(t)
555		return node == nil
556	default:
557		return false
558	}
559}
560
561// concatFunc is the concat function concatenates two or more
562// strings and returns the resulting string.
563// concat( string1 , string2 [, stringn]* )
564func concatFunc(args ...query) func(query, iterator) interface{} {
565	return func(q query, t iterator) interface{} {
566		b := builderPool.Get().(stringBuilder)
567		for _, v := range args {
568			v = functionArgs(v)
569
570			switch v := v.Evaluate(t).(type) {
571			case string:
572				b.WriteString(v)
573			case query:
574				node := v.Select(t)
575				if node != nil {
576					b.WriteString(node.Value())
577				}
578			}
579		}
580		result := b.String()
581		b.Reset()
582		builderPool.Put(b)
583
584		return result
585	}
586}
587
588// https://github.com/antchfx/xpath/issues/43
589func functionArgs(q query) query {
590	if _, ok := q.(*functionQuery); ok {
591		return q
592	}
593	return q.Clone()
594}
595
596func reverseFunc(q query, t iterator) func() NodeNavigator {
597	var list []NodeNavigator
598	for {
599		node := q.Select(t)
600		if node == nil {
601			break
602		}
603		list = append(list, node.Copy())
604	}
605	i := len(list)
606	return func() NodeNavigator {
607		if i <= 0 {
608			return nil
609		}
610		i--
611		node := list[i]
612		return node
613	}
614}
615