1package semver
2
3import (
4	"fmt"
5	"strconv"
6	"strings"
7	"unicode"
8)
9
// wildcardType classifies which component of a version string is an "x"
// wildcard. Its values line up with the component counts mapped by
// wildcardTypefromInt: 1 = major, 2 = minor, 3 = patch.
type wildcardType int

// Wildcard kinds, numbered consecutively from zero.
const (
	noneWildcard  wildcardType = iota // the version contains no wildcard
	majorWildcard                     // the major component is "x"
	minorWildcard                     // the minor component is "x"
	patchWildcard                     // the patch component is "x"
)
18
19func wildcardTypefromInt(i int) wildcardType {
20	switch i {
21	case 1:
22		return majorWildcard
23	case 2:
24		return minorWildcard
25	case 3:
26		return patchWildcard
27	default:
28		return noneWildcard
29	}
30}
31
32type comparator func(Version, Version) bool
33
34var (
35	compEQ comparator = func(v1 Version, v2 Version) bool {
36		return v1.Compare(v2) == 0
37	}
38	compNE = func(v1 Version, v2 Version) bool {
39		return v1.Compare(v2) != 0
40	}
41	compGT = func(v1 Version, v2 Version) bool {
42		return v1.Compare(v2) == 1
43	}
44	compGE = func(v1 Version, v2 Version) bool {
45		return v1.Compare(v2) >= 0
46	}
47	compLT = func(v1 Version, v2 Version) bool {
48		return v1.Compare(v2) == -1
49	}
50	compLE = func(v1 Version, v2 Version) bool {
51		return v1.Compare(v2) <= 0
52	}
53)
54
55type versionRange struct {
56	v Version
57	c comparator
58}
59
60// rangeFunc creates a Range from the given versionRange.
61func (vr *versionRange) rangeFunc() Range {
62	return Range(func(v Version) bool {
63		return vr.c(v, vr.v)
64	})
65}
66
67// Range represents a range of versions.
68// A Range can be used to check if a Version satisfies it:
69//
70//     range, err := semver.ParseRange(">1.0.0 <2.0.0")
71//     range(semver.MustParse("1.1.1") // returns true
72type Range func(Version) bool
73
74// OR combines the existing Range with another Range using logical OR.
75func (rf Range) OR(f Range) Range {
76	return Range(func(v Version) bool {
77		return rf(v) || f(v)
78	})
79}
80
81// AND combines the existing Range with another Range using logical AND.
82func (rf Range) AND(f Range) Range {
83	return Range(func(v Version) bool {
84		return rf(v) && f(v)
85	})
86}
87
88// ParseRange parses a range and returns a Range.
89// If the range could not be parsed an error is returned.
90//
91// Valid ranges are:
92//   - "<1.0.0"
93//   - "<=1.0.0"
94//   - ">1.0.0"
95//   - ">=1.0.0"
96//   - "1.0.0", "=1.0.0", "==1.0.0"
97//   - "!1.0.0", "!=1.0.0"
98//
99// A Range can consist of multiple ranges separated by space:
100// Ranges can be linked by logical AND:
101//   - ">1.0.0 <2.0.0" would match between both ranges, so "1.1.1" and "1.8.7" but not "1.0.0" or "2.0.0"
102//   - ">1.0.0 <3.0.0 !2.0.3-beta.2" would match every version between 1.0.0 and 3.0.0 except 2.0.3-beta.2
103//
104// Ranges can also be linked by logical OR:
105//   - "<2.0.0 || >=3.0.0" would match "1.x.x" and "3.x.x" but not "2.x.x"
106//
107// AND has a higher precedence than OR. It's not possible to use brackets.
108//
109// Ranges can be combined by both AND and OR
110//
111//  - `>1.0.0 <2.0.0 || >3.0.0 !4.2.1` would match `1.2.3`, `1.9.9`, `3.1.1`, but not `4.2.1`, `2.1.1`
112func ParseRange(s string) (Range, error) {
113	parts := splitAndTrim(s)
114	orParts, err := splitORParts(parts)
115	if err != nil {
116		return nil, err
117	}
118	expandedParts, err := expandWildcardVersion(orParts)
119	if err != nil {
120		return nil, err
121	}
122	var orFn Range
123	for _, p := range expandedParts {
124		var andFn Range
125		for _, ap := range p {
126			opStr, vStr, err := splitComparatorVersion(ap)
127			if err != nil {
128				return nil, err
129			}
130			vr, err := buildVersionRange(opStr, vStr)
131			if err != nil {
132				return nil, fmt.Errorf("Could not parse Range %q: %s", ap, err)
133			}
134			rf := vr.rangeFunc()
135
136			// Set function
137			if andFn == nil {
138				andFn = rf
139			} else { // Combine with existing function
140				andFn = andFn.AND(rf)
141			}
142		}
143		if orFn == nil {
144			orFn = andFn
145		} else {
146			orFn = orFn.OR(andFn)
147		}
148
149	}
150	return orFn, nil
151}
152
// splitORParts splits the already cleaned parts by '||' into groups whose
// comparisons are later AND-ed together.
//
// It returns an error in four cases, each of which would otherwise produce
// an empty group:
//   - parts is empty
//   - '||' is the first element
//   - '||' is the last element
//   - two '||' are adjacent
func splitORParts(parts []string) ([][]string, error) {
	if len(parts) == 0 {
		return nil, fmt.Errorf("Range is empty")
	}
	var groups [][]string
	last := 0
	for i, p := range parts {
		if p != "||" {
			continue
		}
		if i == 0 {
			return nil, fmt.Errorf("First element in range is '||'")
		}
		if i == last {
			return nil, fmt.Errorf("Empty range between '||' at position %d", i)
		}
		// Cap each group's capacity so appending to it can never overwrite
		// the elements that follow it in parts.
		groups = append(groups, parts[last:i:i])
		last = i + 1
	}
	if last == len(parts) {
		return nil, fmt.Errorf("Last element in range is '||'")
	}
	groups = append(groups, parts[last:])
	return groups, nil
}
174
175// buildVersionRange takes a slice of 2: operator and version
176// and builds a versionRange, otherwise an error.
177func buildVersionRange(opStr, vStr string) (*versionRange, error) {
178	c := parseComparator(opStr)
179	if c == nil {
180		return nil, fmt.Errorf("Could not parse comparator %q in %q", opStr, strings.Join([]string{opStr, vStr}, ""))
181	}
182	v, err := Parse(vStr)
183	if err != nil {
184		return nil, fmt.Errorf("Could not parse version %q in %q: %s", vStr, strings.Join([]string{opStr, vStr}, ""), err)
185	}
186
187	return &versionRange{
188		v: v,
189		c: c,
190	}, nil
191
192}
193
// inArray reports whether the byte s occurs anywhere in list.
// A nil or empty list contains nothing.
func inArray(s byte, list []byte) bool {
	found := false
	for i := 0; i < len(list) && !found; i++ {
		found = list[i] == s
	}
	return found
}
203
204// splitAndTrim splits a range string by spaces and cleans whitespaces
205func splitAndTrim(s string) (result []string) {
206	last := 0
207	var lastChar byte
208	excludeFromSplit := []byte{'>', '<', '='}
209	for i := 0; i < len(s); i++ {
210		if s[i] == ' ' && !inArray(lastChar, excludeFromSplit) {
211			if last < i-1 {
212				result = append(result, s[last:i])
213			}
214			last = i + 1
215		} else if s[i] != ' ' {
216			lastChar = s[i]
217		}
218	}
219	if last < len(s)-1 {
220		result = append(result, s[last:])
221	}
222
223	for i, v := range result {
224		result[i] = strings.Replace(v, " ", "", -1)
225	}
226
227	// parts := strings.Split(s, " ")
228	// for _, x := range parts {
229	// 	if s := strings.TrimSpace(x); len(s) != 0 {
230	// 		result = append(result, s)
231	// 	}
232	// }
233	return
234}
235
// splitComparatorVersion separates a comparison such as ">=1.2.3" into its
// operator (">=") and its version ("1.2.3"). The version begins at the first
// digit. Everything before that digit, with surrounding whitespace trimmed,
// is the operator. An error is returned if s contains no digit.
func splitComparatorVersion(s string) (string, string, error) {
	for idx, r := range s {
		if !unicode.IsDigit(r) {
			continue
		}
		return strings.TrimSpace(s[:idx]), s[idx:], nil
	}
	return "", "", fmt.Errorf("Could not get version from string: %q", s)
}
245
246// getWildcardType will return the type of wildcard that the
247// passed version contains
248func getWildcardType(vStr string) wildcardType {
249	parts := strings.Split(vStr, ".")
250	nparts := len(parts)
251	wildcard := parts[nparts-1]
252
253	possibleWildcardType := wildcardTypefromInt(nparts)
254	if wildcard == "x" {
255		return possibleWildcardType
256	}
257
258	return noneWildcard
259}
260
// createVersionFromWildcard turns a wildcard version into the lowest
// concrete version it covers:
//   - "1.2.x" becomes "1.2.0"
//   - "1.x" and "1.x.x" become "1.0.0"
// Only the first wildcard component is replaced.
func createVersionFromWildcard(vStr string) string {
	// Collapse "1.x.x" to "1.x" so at most one wildcard remains, then
	// replace that wildcard with "0".
	flat := strings.Replace(vStr, ".x.x", ".x", 1)
	flat = strings.Replace(flat, ".x", ".0", 1)

	// A two-component result ("1.0") still lacks its patch component.
	if strings.Count(flat, ".") == 1 {
		flat += ".0"
	}
	return flat
}
277
// incrementMajorVersion returns vStr with its major (first dot-separated)
// component incremented by one, e.g. "1.0.0" becomes "2.0.0". All remaining
// text is copied unchanged; the minor and patch components are not reset.
// The strconv.Atoi error is returned if the major component is not an
// integer.
func incrementMajorVersion(vStr string) (string, error) {
	major, rest := vStr, ""
	if dot := strings.IndexByte(vStr, '.'); dot >= 0 {
		major, rest = vStr[:dot], vStr[dot:]
	}
	n, err := strconv.Atoi(major)
	if err != nil {
		return "", err
	}
	return strconv.Itoa(n+1) + rest, nil
}
290
// incrementMinorVersion returns vStr with its minor (second dot-separated)
// component incremented by one, e.g. "1.2.0" becomes "1.3.0". Other
// components are copied unchanged; the patch component is not reset.
// An error is returned if vStr has no minor component or if that component
// is not an integer. Previously a version without a dot caused an
// index-out-of-range panic.
func incrementMinorVersion(vStr string) (string, error) {
	parts := strings.Split(vStr, ".")
	if len(parts) < 2 {
		return "", fmt.Errorf("Could not increment minor version of %q: no minor component", vStr)
	}
	i, err := strconv.Atoi(parts[1])
	if err != nil {
		return "", err
	}
	parts[1] = strconv.Itoa(i + 1)

	return strings.Join(parts, "."), nil
}
303
304// expandWildcardVersion will expand wildcards inside versions
305// following these rules:
306//
307// * when dealing with patch wildcards:
308// >= 1.2.x    will become    >= 1.2.0
309// <= 1.2.x    will become    <  1.3.0
310// >  1.2.x    will become    >= 1.3.0
311// <  1.2.x    will become    <  1.2.0
312// != 1.2.x    will become    <  1.2.0 >= 1.3.0
313//
314// * when dealing with minor wildcards:
315// >= 1.x      will become    >= 1.0.0
316// <= 1.x      will become    <  2.0.0
317// >  1.x      will become    >= 2.0.0
318// <  1.0      will become    <  1.0.0
319// != 1.x      will become    <  1.0.0 >= 2.0.0
320//
321// * when dealing with wildcards without
322// version operator:
323// 1.2.x       will become    >= 1.2.0 < 1.3.0
324// 1.x         will become    >= 1.0.0 < 2.0.0
325func expandWildcardVersion(parts [][]string) ([][]string, error) {
326	var expandedParts [][]string
327	for _, p := range parts {
328		var newParts []string
329		for _, ap := range p {
330			if strings.Contains(ap, "x") {
331				opStr, vStr, err := splitComparatorVersion(ap)
332				if err != nil {
333					return nil, err
334				}
335
336				versionWildcardType := getWildcardType(vStr)
337				flatVersion := createVersionFromWildcard(vStr)
338
339				var resultOperator string
340				var shouldIncrementVersion bool
341				switch opStr {
342				case ">":
343					resultOperator = ">="
344					shouldIncrementVersion = true
345				case ">=":
346					resultOperator = ">="
347				case "<":
348					resultOperator = "<"
349				case "<=":
350					resultOperator = "<"
351					shouldIncrementVersion = true
352				case "", "=", "==":
353					newParts = append(newParts, ">="+flatVersion)
354					resultOperator = "<"
355					shouldIncrementVersion = true
356				case "!=", "!":
357					newParts = append(newParts, "<"+flatVersion)
358					resultOperator = ">="
359					shouldIncrementVersion = true
360				}
361
362				var resultVersion string
363				if shouldIncrementVersion {
364					switch versionWildcardType {
365					case patchWildcard:
366						resultVersion, _ = incrementMinorVersion(flatVersion)
367					case minorWildcard:
368						resultVersion, _ = incrementMajorVersion(flatVersion)
369					}
370				} else {
371					resultVersion = flatVersion
372				}
373
374				ap = resultOperator + resultVersion
375			}
376			newParts = append(newParts, ap)
377		}
378		expandedParts = append(expandedParts, newParts)
379	}
380
381	return expandedParts, nil
382}
383
384func parseComparator(s string) comparator {
385	switch s {
386	case "==":
387		fallthrough
388	case "":
389		fallthrough
390	case "=":
391		return compEQ
392	case ">":
393		return compGT
394	case ">=":
395		return compGE
396	case "<":
397		return compLT
398	case "<=":
399		return compLE
400	case "!":
401		fallthrough
402	case "!=":
403		return compNE
404	}
405
406	return nil
407}
408
409// MustParseRange is like ParseRange but panics if the range cannot be parsed.
410func MustParseRange(s string) Range {
411	r, err := ParseRange(s)
412	if err != nil {
413		panic(`semver: ParseRange(` + s + `): ` + err.Error())
414	}
415	return r
416}
417