1package cli
2
3import (
4	"flag"
5	"fmt"
6	"strconv"
7	"strings"
8)
9
// IntSlice is a named []int whose pointer type provides the Set, String and
// Get methods expected by flag.Value and flag.Getter.
type IntSlice []int

// Set converts value to an int with strconv.Atoi and appends it to the
// slice. When the conversion fails the slice is left untouched and the
// strconv error is returned as-is.
func (f *IntSlice) Set(value string) error {
	parsed, err := strconv.Atoi(value)
	if err == nil {
		*f = append(*f, parsed)
	}
	return err
}

// String renders the collected integers in decimal, joined by commas and
// without surrounding brackets (for usage defaults). An empty slice yields
// the empty string.
func (f *IntSlice) String() string {
	parts := make([]string, 0, len(*f))
	for _, n := range *f {
		parts = append(parts, strconv.Itoa(n))
	}
	return strings.Join(parts, ",")
}

// Value returns the collected integers as a plain []int. The result shares
// its backing array with the receiver; it is not a copy.
func (f *IntSlice) Value() []int {
	return []int(*f)
}

// Get returns the collected integers as an IntSlice wrapped in an empty
// interface, which is the accessor required by flag.Getter.
func (f *IntSlice) Get() interface{} {
	var current IntSlice = *f
	return current
}
42
43// IntSliceFlag is a flag with type *IntSlice
44type IntSliceFlag struct {
45	Name     string
46	Usage    string
47	EnvVar   string
48	FilePath string
49	Required bool
50	Hidden   bool
51	Value    *IntSlice
52}
53
54// String returns a readable representation of this value
55// (for usage defaults)
56func (f IntSliceFlag) String() string {
57	return FlagStringer(f)
58}
59
60// GetName returns the name of the flag
61func (f IntSliceFlag) GetName() string {
62	return f.Name
63}
64
65// IsRequired returns whether or not the flag is required
66func (f IntSliceFlag) IsRequired() bool {
67	return f.Required
68}
69
70// TakesValue returns true of the flag takes a value, otherwise false
71func (f IntSliceFlag) TakesValue() bool {
72	return true
73}
74
75// GetUsage returns the usage string for the flag
76func (f IntSliceFlag) GetUsage() string {
77	return f.Usage
78}
79
80// GetValue returns the flags value as string representation and an empty
81// string if the flag takes no value at all.
82func (f IntSliceFlag) GetValue() string {
83	if f.Value != nil {
84		return f.Value.String()
85	}
86	return ""
87}
88
89// Apply populates the flag given the flag set and environment
90// Ignores errors
91func (f IntSliceFlag) Apply(set *flag.FlagSet) {
92	_ = f.ApplyWithError(set)
93}
94
95// ApplyWithError populates the flag given the flag set and environment
96func (f IntSliceFlag) ApplyWithError(set *flag.FlagSet) error {
97	if envVal, ok := flagFromFileEnv(f.FilePath, f.EnvVar); ok {
98		newVal := &IntSlice{}
99		for _, s := range strings.Split(envVal, ",") {
100			s = strings.TrimSpace(s)
101			if err := newVal.Set(s); err != nil {
102				return fmt.Errorf("could not parse %s as int slice value for flag %s: %s", envVal, f.Name, err)
103			}
104		}
105		if f.Value == nil {
106			f.Value = newVal
107		} else {
108			*f.Value = *newVal
109		}
110	}
111
112	eachName(f.Name, func(name string) {
113		if f.Value == nil {
114			f.Value = &IntSlice{}
115		}
116		set.Var(f.Value, name, f.Usage)
117	})
118
119	return nil
120}
121
122// IntSlice looks up the value of a local IntSliceFlag, returns
123// nil if not found
124func (c *Context) IntSlice(name string) []int {
125	return lookupIntSlice(name, c.flagSet)
126}
127
128// GlobalIntSlice looks up the value of a global IntSliceFlag, returns
129// nil if not found
130func (c *Context) GlobalIntSlice(name string) []int {
131	if fs := lookupGlobalFlagSet(name, c); fs != nil {
132		return lookupIntSlice(name, fs)
133	}
134	return nil
135}
136
137func lookupIntSlice(name string, set *flag.FlagSet) []int {
138	f := set.Lookup(name)
139	if f != nil {
140		value, ok := f.Value.(*IntSlice)
141		if !ok {
142			return nil
143		}
144		// extract the slice from asserted value
145		slice := value.Value()
146
147		// extract default value from the flag
148		var defaultVal []int
149		for _, v := range strings.Split(f.DefValue, ",") {
150			if v != "" {
151				intValue, err := strconv.Atoi(v)
152				if err != nil {
153					panic(err)
154				}
155				defaultVal = append(defaultVal, intValue)
156			}
157		}
158		// if the current value is not equal to the default value
159		// remove the default values from the flag
160		if !isIntSliceEqual(slice, defaultVal) {
161			for _, v := range defaultVal {
162				slice = removeFromIntSlice(slice, v)
163			}
164		}
165		return slice
166	}
167	return nil
168}
169
// removeFromIntSlice returns slice without the first element equal to val.
// The result is a freshly allocated slice, so the input is never modified.
// When val does not occur, the input slice itself is returned unchanged.
func removeFromIntSlice(slice []int, val int) []int {
	for idx := range slice {
		if slice[idx] != val {
			continue
		}
		// Copy both sides of the match into new storage.
		out := make([]int, 0, len(slice)-1)
		out = append(out, slice[:idx]...)
		return append(out, slice[idx+1:]...)
	}
	return slice
}
180
// isIntSliceEqual reports whether newValue and defaultValue contain the same
// integers in the same order. A nil slice and an empty non-nil slice are
// treated as different.
func isIntSliceEqual(newValue, defaultValue []int) bool {
	newIsNil, defaultIsNil := newValue == nil, defaultValue == nil
	if newIsNil != defaultIsNil || len(newValue) != len(defaultValue) {
		return false
	}

	// Lengths match, so a single index walks both slices.
	for idx := range defaultValue {
		if newValue[idx] != defaultValue[idx] {
			return false
		}
	}
	return true
}
199