1package xpath
2
import (
	"fmt"
	"math"
	"reflect"
	"strconv"
	"strings"
)
8
// XPath operator implementations: comparison (=, !=, <, <=, >, >=), logical (or) and arithmetic (+, -, *, div, mod).
10
// valueType classifies an operand of an XPath operator by its Go type.
// The numeric values double as row/column indices into logicalFuncs, so they
// must stay in this exact order.
type valueType int

const (
	booleanType valueType = 0 // Go bool operands
	numberType  valueType = 1 // Go float64 operands
	stringType  valueType = 2 // Go string operands
	nodeSetType valueType = 3 // query (node-set) operands
)
20
21func getValueType(i interface{}) valueType {
22	v := reflect.ValueOf(i)
23	switch v.Kind() {
24	case reflect.Float64:
25		return numberType
26	case reflect.String:
27		return stringType
28	case reflect.Bool:
29		return booleanType
30	default:
31		if _, ok := i.(query); ok {
32			return nodeSetType
33		}
34	}
35	panic(fmt.Errorf("xpath unknown value type: %v", v.Kind()))
36}
37
38type logical func(iterator, string, interface{}, interface{}) bool
39
40var logicalFuncs = [][]logical{
41	{cmpBooleanBoolean, nil, nil, nil},
42	{nil, cmpNumericNumeric, cmpNumericString, cmpNumericNodeSet},
43	{nil, cmpStringNumeric, cmpStringString, cmpStringNodeSet},
44	{nil, cmpNodeSetNumeric, cmpNodeSetString, cmpNodeSetNodeSet},
45}
46
// cmpNumberNumberF applies the comparison operator op to two numbers using
// Go's IEEE-754 float comparisons (so any comparison involving NaN is false,
// except "!="). An unrecognised operator yields false.
func cmpNumberNumberF(op string, a, b float64) bool {
	var ok bool
	switch op {
	case "!=":
		ok = a != b
	case "=":
		ok = a == b
	case "<=":
		ok = a <= b
	case ">=":
		ok = a >= b
	case "<":
		ok = a < b
	case ">":
		ok = a > b
	}
	return ok
}
65
// cmpStringStringF applies the comparison operator op to two strings.
// Equality is exact; the relational operators use Go's native byte-wise
// lexicographic string ordering. An unrecognised operator yields false.
func cmpStringStringF(op string, a, b string) bool {
	var ok bool
	switch op {
	case "!=":
		ok = a != b
	case "=":
		ok = a == b
	case "<=":
		ok = a <= b
	case ">=":
		ok = a >= b
	case "<":
		ok = a < b
	case ">":
		ok = a > b
	}
	return ok
}
84
// cmpBooleanBooleanF applies op to two booleans.
//
// "and"/"or" are the logical connectives and "=" / "!=" compare the values.
// The relational operators compare the booleans as numbers (true = 1,
// false = 0), as XPath 1.0 prescribes for non-node-set operands.
// An unrecognised operator yields false.
func cmpBooleanBooleanF(op string, a, b bool) bool {
	switch op {
	case "or":
		return a || b
	case "and":
		return a && b
	case "=":
		return a == b
	case "!=":
		return a != b
	case "<": // 0 < 1 only
		return !a && b
	case "<=": // false only for 1 <= 0
		return !a || b
	case ">": // 1 > 0 only
		return a && !b
	case ">=": // false only for 0 >= 1
		return a || !b
	}
	return false
}
94
95func cmpNumericNumeric(t iterator, op string, m, n interface{}) bool {
96	a := m.(float64)
97	b := n.(float64)
98	return cmpNumberNumberF(op, a, b)
99}
100
101func cmpNumericString(t iterator, op string, m, n interface{}) bool {
102	a := m.(float64)
103	b := n.(string)
104	num, err := strconv.ParseFloat(b, 64)
105	if err != nil {
106		panic(err)
107	}
108	return cmpNumberNumberF(op, a, num)
109}
110
111func cmpNumericNodeSet(t iterator, op string, m, n interface{}) bool {
112	a := m.(float64)
113	b := n.(query)
114
115	for {
116		node := b.Select(t)
117		if node == nil {
118			break
119		}
120		num, err := strconv.ParseFloat(node.Value(), 64)
121		if err != nil {
122			panic(err)
123		}
124		if cmpNumberNumberF(op, a, num) {
125			return true
126		}
127	}
128	return false
129}
130
131func cmpNodeSetNumeric(t iterator, op string, m, n interface{}) bool {
132	a := m.(query)
133	b := n.(float64)
134	for {
135		node := a.Select(t)
136		if node == nil {
137			break
138		}
139		num, err := strconv.ParseFloat(node.Value(), 64)
140		if err != nil {
141			panic(err)
142		}
143		if cmpNumberNumberF(op, num, b) {
144			return true
145		}
146	}
147	return false
148}
149
150func cmpNodeSetString(t iterator, op string, m, n interface{}) bool {
151	a := m.(query)
152	b := n.(string)
153	for {
154		node := a.Select(t)
155		if node == nil {
156			break
157		}
158		if cmpStringStringF(op, b, node.Value()) {
159			return true
160		}
161	}
162	return false
163}
164
165func cmpNodeSetNodeSet(t iterator, op string, m, n interface{}) bool {
166	a := m.(query)
167	b := n.(query)
168	x := a.Select(t)
169	if x == nil {
170		return false
171	}
172	y := b.Select(t)
173	if y == nil {
174		return false
175	}
176	return cmpStringStringF(op, x.Value(), y.Value())
177}
178
179func cmpStringNumeric(t iterator, op string, m, n interface{}) bool {
180	a := m.(string)
181	b := n.(float64)
182	num, err := strconv.ParseFloat(a, 64)
183	if err != nil {
184		panic(err)
185	}
186	return cmpNumberNumberF(op, b, num)
187}
188
189func cmpStringString(t iterator, op string, m, n interface{}) bool {
190	a := m.(string)
191	b := n.(string)
192	return cmpStringStringF(op, a, b)
193}
194
195func cmpStringNodeSet(t iterator, op string, m, n interface{}) bool {
196	a := m.(string)
197	b := n.(query)
198	for {
199		node := b.Select(t)
200		if node == nil {
201			break
202		}
203		if cmpStringStringF(op, a, node.Value()) {
204			return true
205		}
206	}
207	return false
208}
209
210func cmpBooleanBoolean(t iterator, op string, m, n interface{}) bool {
211	a := m.(bool)
212	b := n.(bool)
213	return cmpBooleanBooleanF(op, a, b)
214}
215
216// eqFunc is an `=` operator.
217func eqFunc(t iterator, m, n interface{}) interface{} {
218	t1 := getValueType(m)
219	t2 := getValueType(n)
220	return logicalFuncs[t1][t2](t, "=", m, n)
221}
222
223// gtFunc is an `>` operator.
224func gtFunc(t iterator, m, n interface{}) interface{} {
225	t1 := getValueType(m)
226	t2 := getValueType(n)
227	return logicalFuncs[t1][t2](t, ">", m, n)
228}
229
230// geFunc is an `>=` operator.
231func geFunc(t iterator, m, n interface{}) interface{} {
232	t1 := getValueType(m)
233	t2 := getValueType(n)
234	return logicalFuncs[t1][t2](t, ">=", m, n)
235}
236
237// ltFunc is an `<` operator.
238func ltFunc(t iterator, m, n interface{}) interface{} {
239	t1 := getValueType(m)
240	t2 := getValueType(n)
241	return logicalFuncs[t1][t2](t, "<", m, n)
242}
243
244// leFunc is an `<=` operator.
245func leFunc(t iterator, m, n interface{}) interface{} {
246	t1 := getValueType(m)
247	t2 := getValueType(n)
248	return logicalFuncs[t1][t2](t, "<=", m, n)
249}
250
251// neFunc is an `!=` operator.
252func neFunc(t iterator, m, n interface{}) interface{} {
253	t1 := getValueType(m)
254	t2 := getValueType(n)
255	return logicalFuncs[t1][t2](t, "!=", m, n)
256}
257
258// orFunc is an `or` operator.
259var orFunc = func(t iterator, m, n interface{}) interface{} {
260	t1 := getValueType(m)
261	t2 := getValueType(n)
262	return logicalFuncs[t1][t2](t, "or", m, n)
263}
264
// numericExpr converts both operands to numbers and returns cb applied to
// them; it backs the arithmetic operators.
//
// Conversion follows XPath number() for the Go types visible here:
//   - float64 is used as-is;
//   - a string is parsed after trimming surrounding whitespace, and is NaN
//     when it is not a number (±Inf is kept for out-of-range values);
//   - a bool becomes 1 (true) or 0 (false).
//
// Any other value goes through reflect conversion, which accepts the
// remaining Go numeric kinds and panics for everything else (e.g. a node-set
// query, which cannot be evaluated here without an iterator).
func numericExpr(m, n interface{}, cb func(float64, float64) float64) float64 {
	toNumber := func(v interface{}) float64 {
		switch x := v.(type) {
		case float64:
			return x
		case string:
			f, err := strconv.ParseFloat(strings.TrimSpace(x), 64)
			if err != nil && !math.IsInf(f, 0) {
				// Syntax error: XPath yields NaN rather than failing.
				return math.NaN()
			}
			return f
		case bool:
			if x {
				return 1
			}
			return 0
		}
		return reflect.ValueOf(v).Convert(reflect.TypeOf(float64(0))).Float()
	}
	return cb(toNumber(m), toNumber(n))
}
271
272// plusFunc is an `+` operator.
273var plusFunc = func(m, n interface{}) interface{} {
274	return numericExpr(m, n, func(a, b float64) float64 {
275		return a + b
276	})
277}
278
279// minusFunc is an `-` operator.
280var minusFunc = func(m, n interface{}) interface{} {
281	return numericExpr(m, n, func(a, b float64) float64 {
282		return a - b
283	})
284}
285
286// mulFunc is an `*` operator.
287var mulFunc = func(m, n interface{}) interface{} {
288	return numericExpr(m, n, func(a, b float64) float64 {
289		return a * b
290	})
291}
292
293// divFunc is an `DIV` operator.
294var divFunc = func(m, n interface{}) interface{} {
295	return numericExpr(m, n, func(a, b float64) float64 {
296		return a / b
297	})
298}
299
300// modFunc is an 'MOD' operator.
301var modFunc = func(m, n interface{}) interface{} {
302	return numericExpr(m, n, func(a, b float64) float64 {
303		return float64(int(a) % int(b))
304	})
305}
306