// Package options resolves configuration values set via command line flags, config files,
// and default struct values.
3package options
4
5import (
6	"flag"
7	"fmt"
8	"log"
9	"reflect"
10	"strconv"
11	"strings"
12	"time"
13)
14
15// Resolve combines configuration values set via command line flags (FlagSet) or an externally
16// parsed config file (map) onto an options struct.
17//
18// The options struct supports struct tags "flag", "cfg", and "deprecated", ex:
19//
20// 	type Options struct {
21// 		MaxSize     int64         `flag:"max-size" cfg:"max_size"`
22// 		Timeout     time.Duration `flag:"timeout" cfg:"timeout"`
23// 		Description string        `flag:"description" cfg:"description"`
24// 	}
25//
26// Values are resolved with the following priorities (highest to lowest):
27//
28//   1. Command line flag
29//   2. Deprecated command line flag
30//   3. Config file value
31//   4. Get() value (if Getter)
32//   5. Options struct default value
33//
34func Resolve(options interface{}, flagSet *flag.FlagSet, cfg map[string]interface{}) {
35	val := reflect.ValueOf(options).Elem()
36	typ := val.Type()
37	for i := 0; i < typ.NumField(); i++ {
38		// pull out the struct tags:
39		//    flag - the name of the command line flag
40		//    deprecated - (optional) the name of the deprecated command line flag
41		//    cfg - (optional, defaults to underscored flag) the name of the config file option
42		field := typ.Field(i)
43
44		// Recursively resolve embedded types.
45		if field.Anonymous {
46			var fieldPtr reflect.Value
47			switch val.FieldByName(field.Name).Kind() {
48			case reflect.Struct:
49				fieldPtr = val.FieldByName(field.Name).Addr()
50			case reflect.Ptr:
51				fieldPtr = reflect.Indirect(val).FieldByName(field.Name)
52			}
53			if !fieldPtr.IsNil() {
54				Resolve(fieldPtr.Interface(), flagSet, cfg)
55			}
56		}
57
58		flagName := field.Tag.Get("flag")
59		deprecatedFlagName := field.Tag.Get("deprecated")
60		cfgName := field.Tag.Get("cfg")
61		if flagName == "" {
62			// resolvable fields must have at least the `flag` struct tag
63			continue
64		}
65		if cfgName == "" {
66			cfgName = strings.Replace(flagName, "-", "_", -1)
67		}
68
69		// lookup the flags upfront because it's a programming error
70		// if they aren't found (hence the panic)
71		flagInst := flagSet.Lookup(flagName)
72		if flagInst == nil {
73			log.Panicf("ERROR: flag %q does not exist", flagName)
74		}
75		var deprecatedFlag *flag.Flag
76		if deprecatedFlagName != "" {
77			deprecatedFlag = flagSet.Lookup(deprecatedFlagName)
78			if deprecatedFlag == nil {
79				log.Panicf("ERROR: deprecated flag %q does not exist", deprecatedFlagName)
80			}
81		}
82
83		// resolve the flags according to priority
84		var v interface{}
85		if hasArg(flagSet, flagName) {
86			v = flagInst.Value.(flag.Getter).Get()
87		} else if deprecatedFlagName != "" && hasArg(flagSet, deprecatedFlagName) {
88			v = deprecatedFlag.Value.(flag.Getter).Get()
89			log.Printf("WARNING: use of the --%s command line flag is deprecated (use --%s)",
90				deprecatedFlagName, flagName)
91		} else if cfgVal, ok := cfg[cfgName]; ok {
92			v = cfgVal
93		} else if getter, ok := flagInst.Value.(flag.Getter); ok {
94			// if the type has a Get() method, use that as the default value
95			v = getter.Get()
96		} else {
97			// otherwise, use the struct's default value
98			v = val.Field(i).Interface()
99		}
100
101		fieldVal := val.FieldByName(field.Name)
102		if fieldVal.Type() != reflect.TypeOf(v) {
103			newv, err := coerce(v, fieldVal.Interface())
104			if err != nil {
105				log.Fatalf("ERROR: Resolve failed to coerce value %v (%+v) for field %s - %s",
106					v, fieldVal, field.Name, err)
107			}
108			v = newv
109		}
110		fieldVal.Set(reflect.ValueOf(v))
111	}
112}
113
// coerceBool converts v to a bool. Booleans pass through, strings are parsed with
// strconv.ParseBool, and integers (signed or unsigned) are true when non-zero, matching
// ParseBool's treatment of "0"/"1". Any other type is an error.
func coerceBool(v interface{}) (bool, error) {
	switch v := v.(type) {
	case bool:
		return v, nil
	case string:
		return strconv.ParseBool(v)
	case int, int8, int16, int32, int64:
		return reflect.ValueOf(v).Int() != 0, nil
	case uint, uint8, uint16, uint32, uint64:
		// Value.Int panics on unsigned kinds, so they need Value.Uint.
		return reflect.ValueOf(v).Uint() != 0, nil
	}
	return false, fmt.Errorf("invalid bool value type %T", v)
}
123
// coerceInt64 converts v to an int64. Strings are parsed as base-10 integers; signed and
// unsigned integers are converted, with unsigned values above math.MaxInt64 rejected
// instead of silently wrapping negative; floats (as produced by JSON/TOML decoders) are
// accepted only when they hold an integral value within int64 range. Any other type is an
// error.
func coerceInt64(v interface{}) (int64, error) {
	switch v := v.(type) {
	case string:
		return strconv.ParseInt(v, 10, 64)
	case int, int8, int16, int32, int64:
		return reflect.ValueOf(v).Int(), nil
	case uint, uint8, uint16, uint32, uint64:
		u := reflect.ValueOf(v).Uint()
		if u > 1<<63-1 {
			return 0, fmt.Errorf("value %d overflows int64", u)
		}
		return int64(u), nil
	case float32, float64:
		f := reflect.ValueOf(v).Float()
		// Range-check before converting (out-of-range float->int conversion is
		// implementation-defined); the round-trip check rejects fractions and NaN.
		if f < -(1<<63) || f >= 1<<63 || f != float64(int64(f)) {
			return 0, fmt.Errorf("value %v is not a valid int64", f)
		}
		return int64(f), nil
	}
	return 0, fmt.Errorf("invalid int64 value type %T", v)
}
135
// coerceFloat64 converts v to a float64. Strings are parsed with strconv.ParseFloat,
// floats are widened, and signed/unsigned integers (e.g. a config value written as `5`)
// are converted. Any other type is an error.
func coerceFloat64(v interface{}) (float64, error) {
	switch v := v.(type) {
	case string:
		return strconv.ParseFloat(v, 64)
	case float32, float64:
		return reflect.ValueOf(v).Float(), nil
	case int, int8, int16, int32, int64:
		return float64(reflect.ValueOf(v).Int()), nil
	case uint, uint8, uint16, uint32, uint64:
		return float64(reflect.ValueOf(v).Uint()), nil
	}
	return 0, fmt.Errorf("invalid float64 value type %T", v)
}
145
// coerceDuration converts v to a time.Duration. Strings are parsed with
// time.ParseDuration (e.g. "2s"); integers, signed or unsigned, are interpreted as
// milliseconds. Any other type is an error.
func coerceDuration(v interface{}) (time.Duration, error) {
	switch v := v.(type) {
	case string:
		return time.ParseDuration(v)
	case int, int8, int16, int32, int64:
		// treat like ms
		return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil
	case uint, uint8, uint16, uint32, uint64:
		// treat like ms; Value.Int panics on unsigned kinds, so use Value.Uint
		return time.Duration(reflect.ValueOf(v).Uint()) * time.Millisecond, nil
	}
	return 0, fmt.Errorf("invalid time.Duration value type %T", v)
}
156
// coerceStringSlice converts v to a []string. A string is split on ","; a []string is
// copied; a []interface{} (as produced by config decoders) must contain only strings.
// A nil v yields a nil slice. Non-string elements and unsupported types are reported as
// errors instead of panicking or silently producing an empty slice.
func coerceStringSlice(v interface{}) ([]string, error) {
	switch v := v.(type) {
	case nil:
		return nil, nil
	case string:
		return strings.Split(v, ","), nil
	case []string:
		return append([]string(nil), v...), nil
	case []interface{}:
		var out []string
		for i, elem := range v {
			s, ok := elem.(string)
			if !ok {
				return nil, fmt.Errorf("invalid []string element %d type %T", i, elem)
			}
			out = append(out, s)
		}
		return out, nil
	}
	return nil, fmt.Errorf("invalid []string value type %T", v)
}
171
// coerceFloat64Slice converts v to a []float64. A string is split on "," and a []string is
// used element-wise; string elements are whitespace-trimmed and parsed as floats. Elements
// of a []interface{} (as produced by config decoders) may be strings, floats, or signed or
// unsigned integers. A nil v yields a nil slice. Unparseable elements and unsupported
// types are reported as errors instead of panicking or silently producing an empty slice.
func coerceFloat64Slice(v interface{}) ([]float64, error) {
	var elems []interface{}
	switch v := v.(type) {
	case nil:
		return nil, nil
	case string:
		for _, s := range strings.Split(v, ",") {
			elems = append(elems, s)
		}
	case []string:
		for _, s := range v {
			elems = append(elems, s)
		}
	case []interface{}:
		elems = v
	default:
		return nil, fmt.Errorf("invalid []float64 value type %T", v)
	}

	var out []float64
	for _, elem := range elems {
		var f float64
		switch e := elem.(type) {
		case string:
			parsed, err := strconv.ParseFloat(strings.TrimSpace(e), 64)
			if err != nil {
				return nil, err
			}
			f = parsed
		case float32, float64:
			f = reflect.ValueOf(e).Float()
		case int, int8, int16, int32, int64:
			f = float64(reflect.ValueOf(e).Int())
		case uint, uint8, uint16, uint32, uint64:
			f = float64(reflect.ValueOf(e).Uint())
		default:
			return nil, fmt.Errorf("invalid []float64 element type %T", e)
		}
		out = append(out, f)
	}
	return out, nil
}
198
// coerceString converts v to a string. Strings pass through, []byte is converted
// directly, nil becomes "", and every other value uses its default fmt formatting
// (fmt.Stringer/error are honoured). This avoids `%s` producing "%!s(int=5)" for
// non-string values such as numbers or booleans read from a config file.
func coerceString(v interface{}) (string, error) {
	switch v := v.(type) {
	case string:
		return v, nil
	case []byte:
		return string(v), nil
	case nil:
		return "", nil
	}
	return fmt.Sprint(v), nil
}
202
203func coerce(v interface{}, opt interface{}) (interface{}, error) {
204	switch opt.(type) {
205	case bool:
206		return coerceBool(v)
207	case int:
208		i, err := coerceInt64(v)
209		if err != nil {
210			return nil, err
211		}
212		return int(i), nil
213	case int16:
214		i, err := coerceInt64(v)
215		if err != nil {
216			return nil, err
217		}
218		return int16(i), nil
219	case uint16:
220		i, err := coerceInt64(v)
221		if err != nil {
222			return nil, err
223		}
224		return uint16(i), nil
225	case int32:
226		i, err := coerceInt64(v)
227		if err != nil {
228			return nil, err
229		}
230		return int32(i), nil
231	case uint32:
232		i, err := coerceInt64(v)
233		if err != nil {
234			return nil, err
235		}
236		return uint32(i), nil
237	case int64:
238		return coerceInt64(v)
239	case uint64:
240		i, err := coerceInt64(v)
241		if err != nil {
242			return nil, err
243		}
244		return uint64(i), nil
245	case float32:
246		i, err := coerceFloat64(v)
247		if err != nil {
248			return nil, err
249		}
250		return float32(i), nil
251	case float64:
252		i, err := coerceFloat64(v)
253		if err != nil {
254			return nil, err
255		}
256		return float64(i), nil
257	case string:
258		return coerceString(v)
259	case time.Duration:
260		return coerceDuration(v)
261	case []string:
262		return coerceStringSlice(v)
263	case []float64:
264		return coerceFloat64Slice(v)
265	}
266	return nil, fmt.Errorf("invalid value type %T", v)
267}
268
// hasArg reports whether the flag named s was explicitly set on fs's command line,
// i.e. whether it is among the flags visited by fs.Visit.
func hasArg(fs *flag.FlagSet, s string) bool {
	set := false
	fs.Visit(func(f *flag.Flag) {
		set = set || f.Name == s
	})
	return set
}
278