1package flags
2
3import (
4	"fmt"
5	"os"
6	"path/filepath"
7	"reflect"
8	"sort"
9	"strconv"
10	"time"
11
12	"github.com/mitchellh/mapstructure"
13)
14
// TODO (slackpad) - Trying out a different pattern here for config handling.
// These types implement the flag.Value interface but work in a manner where
// we can tell whether they have been set. This lets us work with an
// all-pointer config structure and merge it in a clean-ish way. If this ends
// up being a good pattern we should pull it out into a reusable library.
20
// ConfigDecodeHook should be passed to mapstructure in order to decode into
// the *Value objects here. It composes one hook per supported translation:
//   - bool    -> BoolValue
//   - string  -> DurationValue (parsed with time.ParseDuration)
//   - string  -> StringValue
//   - float64 -> UintValue
//
// Each hook returns its input unchanged unless both the source kind and the
// target type match. The target types are all distinct, so at most one of
// these hooks acts on any given field.
var ConfigDecodeHook = mapstructure.ComposeDecodeHookFunc(
	BoolToBoolValueFunc(),
	StringToDurationValueFunc(),
	StringToStringValueFunc(),
	Float64ToUintValueFunc(),
)
29
// BoolValue provides a flag value that's aware if it has been set.
//
// The zero value is unset. It implements flag.Value (Set and String) and
// marks itself as a boolean flag via IsBoolFlag.
type BoolValue struct {
	// v stays nil until a value has been successfully assigned.
	v *bool
}

// IsBoolFlag is an optional method of the flag.Value
// interface which marks this value as boolean when
// the return value is true. See flag.Value for details.
func (b *BoolValue) IsBoolFlag() bool {
	return true
}

// Merge will overlay this value onto *onto if it has been set. An unset
// value leaves *onto untouched.
func (b *BoolValue) Merge(onto *bool) {
	if b.v != nil {
		*onto = *b.v
	}
}

// Set implements the flag.Value interface. The input is parsed with
// strconv.ParseBool. On a parse error the error is returned and the value is
// left exactly as it was, so an unset value stays unset and a later Merge
// will not overwrite the target with a bogus false.
func (b *BoolValue) Set(v string) error {
	parsed, err := strconv.ParseBool(v)
	if err != nil {
		return err
	}
	if b.v == nil {
		b.v = new(bool)
	}
	*b.v = parsed
	return nil
}

// String implements the flag.Value interface. An unset value renders as
// "false".
func (b *BoolValue) String() string {
	var current bool
	if b.v != nil {
		current = *b.v
	}
	return fmt.Sprintf("%v", current)
}
67
68// BoolToBoolValueFunc is a mapstructure hook that looks for an incoming bool
69// mapped to a BoolValue and does the translation.
70func BoolToBoolValueFunc() mapstructure.DecodeHookFunc {
71	return func(
72		f reflect.Type,
73		t reflect.Type,
74		data interface{}) (interface{}, error) {
75		if f.Kind() != reflect.Bool {
76			return data, nil
77		}
78
79		val := BoolValue{}
80		if t != reflect.TypeOf(val) {
81			return data, nil
82		}
83
84		val.v = new(bool)
85		*(val.v) = data.(bool)
86		return val, nil
87	}
88}
89
// DurationValue provides a flag value that's aware if it has been set.
//
// The zero value is unset. It implements flag.Value (Set and String).
type DurationValue struct {
	// v stays nil until a value has been successfully assigned.
	v *time.Duration
}

// Merge will overlay this value onto *onto if it has been set. An unset
// value leaves *onto untouched.
func (d *DurationValue) Merge(onto *time.Duration) {
	if d.v != nil {
		*onto = *d.v
	}
}

// Set implements the flag.Value interface. The input is parsed with
// time.ParseDuration. On a parse error the error is returned and the value
// is left exactly as it was, so an unset value stays unset rather than
// becoming a "set" zero duration.
func (d *DurationValue) Set(v string) error {
	parsed, err := time.ParseDuration(v)
	if err != nil {
		return err
	}
	if d.v == nil {
		d.v = new(time.Duration)
	}
	*d.v = parsed
	return nil
}

// String implements the flag.Value interface. An unset value renders as the
// zero duration ("0s").
func (d *DurationValue) String() string {
	var current time.Duration
	if d.v != nil {
		current = *d.v
	}
	return current.String()
}
120
121// StringToDurationValueFunc is a mapstructure hook that looks for an incoming
122// string mapped to a DurationValue and does the translation.
123func StringToDurationValueFunc() mapstructure.DecodeHookFunc {
124	return func(
125		f reflect.Type,
126		t reflect.Type,
127		data interface{}) (interface{}, error) {
128		if f.Kind() != reflect.String {
129			return data, nil
130		}
131
132		val := DurationValue{}
133		if t != reflect.TypeOf(val) {
134			return data, nil
135		}
136		if err := val.Set(data.(string)); err != nil {
137			return nil, err
138		}
139		return val, nil
140	}
141}
142
// StringValue is a string flag value that remembers whether it was ever
// assigned. A nil pointer means "never set", which lets callers tell an
// explicit empty string apart from no value at all.
type StringValue struct {
	v *string
}

// Merge copies the held string into *onto, but only once the value has been
// set. Otherwise *onto keeps whatever it already held.
func (s *StringValue) Merge(onto *string) {
	if s.v == nil {
		return
	}
	*onto = *s.v
}

// Set implements the flag.Value interface. The input is stored verbatim and
// never fails. Existing storage is reused when the value was already set.
func (s *StringValue) Set(v string) error {
	if s.v != nil {
		*s.v = v
		return nil
	}
	s.v = &v
	return nil
}

// String implements the flag.Value interface. An unset value renders as "".
func (s *StringValue) String() string {
	if s.v == nil {
		return ""
	}
	return *s.v
}
172
173// StringToStringValueFunc is a mapstructure hook that looks for an incoming
174// string mapped to a StringValue and does the translation.
175func StringToStringValueFunc() mapstructure.DecodeHookFunc {
176	return func(
177		f reflect.Type,
178		t reflect.Type,
179		data interface{}) (interface{}, error) {
180		if f.Kind() != reflect.String {
181			return data, nil
182		}
183
184		val := StringValue{}
185		if t != reflect.TypeOf(val) {
186			return data, nil
187		}
188		val.v = new(string)
189		*(val.v) = data.(string)
190		return val, nil
191	}
192}
193
// UintValue provides a flag value that's aware if it has been set.
//
// The zero value is unset. It implements flag.Value (Set and String).
type UintValue struct {
	// v stays nil until a value has been successfully assigned.
	v *uint
}

// Merge will overlay this value onto *onto if it has been set. An unset
// value leaves *onto untouched.
func (u *UintValue) Merge(onto *uint) {
	if u.v != nil {
		*onto = *u.v
	}
}

// Set implements the flag.Value interface. The input is parsed with
// strconv.ParseUint using base 0, so prefixes such as "0x" are honored. On
// any error (bad syntax or out of range) the error is returned and the value
// is left exactly as it was.
func (u *UintValue) Set(v string) error {
	// bitSize 0 means "must fit in a uint", so on 32-bit platforms oversized
	// inputs are rejected instead of being silently truncated by the
	// conversion below.
	parsed, err := strconv.ParseUint(v, 0, 0)
	if err != nil {
		return err
	}
	if u.v == nil {
		u.v = new(uint)
	}
	*u.v = uint(parsed)
	return nil
}

// String implements the flag.Value interface. An unset value renders as "0".
func (u *UintValue) String() string {
	var current uint
	if u.v != nil {
		current = *u.v
	}
	return fmt.Sprintf("%v", current)
}
224
225// Float64ToUintValueFunc is a mapstructure hook that looks for an incoming
226// float64 mapped to a UintValue and does the translation.
227func Float64ToUintValueFunc() mapstructure.DecodeHookFunc {
228	return func(
229		f reflect.Type,
230		t reflect.Type,
231		data interface{}) (interface{}, error) {
232		if f.Kind() != reflect.Float64 {
233			return data, nil
234		}
235
236		val := UintValue{}
237		if t != reflect.TypeOf(val) {
238			return data, nil
239		}
240
241		fv := data.(float64)
242		if fv < 0 {
243			return nil, fmt.Errorf("value cannot be negative")
244		}
245
246		// The standard guarantees at least this, and this is fine for
247		// values we expect to use in configs vs. being fancy with the
248		// machine's size for uint.
249		if fv > (1<<32 - 1) {
250			return nil, fmt.Errorf("value is too large")
251		}
252
253		val.v = new(uint)
254		*(val.v) = (uint)(fv)
255		return val, nil
256	}
257}
258
// VisitFn is a callback that gets a chance to visit each file found during a
// traversal with Visit. It receives the path of the file being visited.
// Returning a non-nil error stops the traversal, and Visit returns that
// error annotated with the path.
type VisitFn func(path string) error
262
263// Visit will call the visitor function on the path if it's a file, or for each
264// file in the path if it's a directory. Directories will not be recursed into,
265// and files in the directory will be visited in alphabetical order.
266func Visit(path string, visitor VisitFn) error {
267	f, err := os.Open(path)
268	if err != nil {
269		return fmt.Errorf("error reading %q: %v", path, err)
270	}
271	defer f.Close()
272
273	fi, err := f.Stat()
274	if err != nil {
275		return fmt.Errorf("error checking %q: %v", path, err)
276	}
277
278	if !fi.IsDir() {
279		if err := visitor(path); err != nil {
280			return fmt.Errorf("error in %q: %v", path, err)
281		}
282		return nil
283	}
284
285	contents, err := f.Readdir(-1)
286	if err != nil {
287		return fmt.Errorf("error listing %q: %v", path, err)
288	}
289
290	sort.Sort(dirEnts(contents))
291	for _, fi := range contents {
292		if fi.IsDir() {
293			continue
294		}
295
296		fullPath := filepath.Join(path, fi.Name())
297		if err := visitor(fullPath); err != nil {
298			return fmt.Errorf("error in %q: %v", fullPath, err)
299		}
300	}
301
302	return nil
303}
304
// dirEnts adapts a slice of os.FileInfo to sort.Interface, ordering entries
// by Name in ascending byte order.
type dirEnts []os.FileInfo

// Len returns the number of entries.
func (d dirEnts) Len() int {
	return len(d)
}

// Less reports whether entry i's name sorts before entry j's name.
func (d dirEnts) Less(i, j int) bool {
	left, right := d[i].Name(), d[j].Name()
	return left < right
}

// Swap exchanges entries i and j in place.
func (d dirEnts) Swap(i, j int) {
	tmp := d[i]
	d[i] = d[j]
	d[j] = tmp
}
311