1// options resolves configuration values set via command line flags, config files, and default 2// struct values 3package options 4 5import ( 6 "flag" 7 "fmt" 8 "log" 9 "reflect" 10 "strconv" 11 "strings" 12 "time" 13) 14 15// Resolve combines configuration values set via command line flags (FlagSet) or an externally 16// parsed config file (map) onto an options struct. 17// 18// The options struct supports struct tags "flag", "cfg", and "deprecated", ex: 19// 20// type Options struct { 21// MaxSize int64 `flag:"max-size" cfg:"max_size"` 22// Timeout time.Duration `flag:"timeout" cfg:"timeout"` 23// Description string `flag:"description" cfg:"description"` 24// } 25// 26// Values are resolved with the following priorities (highest to lowest): 27// 28// 1. Command line flag 29// 2. Deprecated command line flag 30// 3. Config file value 31// 4. Get() value (if Getter) 32// 5. Options struct default value 33// 34func Resolve(options interface{}, flagSet *flag.FlagSet, cfg map[string]interface{}) { 35 val := reflect.ValueOf(options).Elem() 36 typ := val.Type() 37 for i := 0; i < typ.NumField(); i++ { 38 // pull out the struct tags: 39 // flag - the name of the command line flag 40 // deprecated - (optional) the name of the deprecated command line flag 41 // cfg - (optional, defaults to underscored flag) the name of the config file option 42 field := typ.Field(i) 43 44 // Recursively resolve embedded types. 45 if field.Anonymous { 46 var fieldPtr reflect.Value 47 switch val.FieldByName(field.Name).Kind() { 48 case reflect.Struct: 49 fieldPtr = val.FieldByName(field.Name).Addr() 50 case reflect.Ptr: 51 fieldPtr = reflect.Indirect(val).FieldByName(field.Name) 52 } 53 if !fieldPtr.IsNil() { 54 Resolve(fieldPtr.Interface(), flagSet, cfg) 55 } 56 } 57 58 flagName := field.Tag.Get("flag") 59 deprecatedFlagName := field.Tag.Get("deprecated") 60 cfgName := field.Tag.Get("cfg") 61 if flagName == "" { 62 // resolvable fields must have at least the `flag` struct tag 63 continue 64 } 65 if cfgName == "" { 66 cfgName = strings.Replace(flagName, "-", "_", -1) 67 } 68 69 // lookup the flags upfront because it's a programming error 70 // if they aren't found (hence the panic) 71 flagInst := flagSet.Lookup(flagName) 72 if flagInst == nil { 73 log.Panicf("ERROR: flag %q does not exist", flagName) 74 } 75 var deprecatedFlag *flag.Flag 76 if deprecatedFlagName != "" { 77 deprecatedFlag = flagSet.Lookup(deprecatedFlagName) 78 if deprecatedFlag == nil { 79 log.Panicf("ERROR: deprecated flag %q does not exist", deprecatedFlagName) 80 } 81 } 82 83 // resolve the flags according to priority 84 var v interface{} 85 if hasArg(flagSet, flagName) { 86 v = flagInst.Value.(flag.Getter).Get() 87 } else if deprecatedFlagName != "" && hasArg(flagSet, deprecatedFlagName) { 88 v = deprecatedFlag.Value.(flag.Getter).Get() 89 log.Printf("WARNING: use of the --%s command line flag is deprecated (use --%s)", 90 deprecatedFlagName, flagName) 91 } else if cfgVal, ok := cfg[cfgName]; ok { 92 v = cfgVal 93 } else if getter, ok := flagInst.Value.(flag.Getter); ok { 94 // if the type has a Get() method, use that as the default value 95 v = getter.Get() 96 } else { 97 // otherwise, use the struct's default value 98 v = val.Field(i).Interface() 99 } 100 101 fieldVal := val.FieldByName(field.Name) 102 if fieldVal.Type() != reflect.TypeOf(v) { 103 newv, err := coerce(v, fieldVal.Interface()) 104 if err != nil { 105 log.Fatalf("ERROR: Resolve failed to coerce value %v (%+v) for field %s - %s", 106 v, fieldVal, field.Name, err) 107 } 108 v = newv 109 } 110 fieldVal.Set(reflect.ValueOf(v)) 111 } 112} 113 114func coerceBool(v interface{}) (bool, error) { 115 switch v.(type) { 116 case string: 117 return strconv.ParseBool(v.(string)) 118 case int, int16, uint16, int32, uint32, int64, uint64: 119 return reflect.ValueOf(v).Int() == 0, nil 120 } 121 return false, fmt.Errorf("invalid bool value type %T", v) 122} 123 124func coerceInt64(v interface{}) (int64, error) { 125 switch v.(type) { 126 case string: 127 return strconv.ParseInt(v.(string), 10, 64) 128 case int, int16, int32, int64: 129 return reflect.ValueOf(v).Int(), nil 130 case uint16, uint32, uint64: 131 return int64(reflect.ValueOf(v).Uint()), nil 132 } 133 return 0, fmt.Errorf("invalid int64 value type %T", v) 134} 135 136func coerceFloat64(v interface{}) (float64, error) { 137 switch v.(type) { 138 case string: 139 return strconv.ParseFloat(v.(string), 64) 140 case float32, float64: 141 return reflect.ValueOf(v).Float(), nil 142 } 143 return 0, fmt.Errorf("invalid float64 value type %T", v) 144} 145 146func coerceDuration(v interface{}) (time.Duration, error) { 147 switch v.(type) { 148 case string: 149 return time.ParseDuration(v.(string)) 150 case int, int16, uint16, int32, uint32, int64, uint64: 151 // treat like ms 152 return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil 153 } 154 return 0, fmt.Errorf("invalid time.Duration value type %T", v) 155} 156 157func coerceStringSlice(v interface{}) ([]string, error) { 158 var tmp []string 159 switch v.(type) { 160 case string: 161 for _, s := range strings.Split(v.(string), ",") { 162 tmp = append(tmp, s) 163 } 164 case []interface{}: 165 for _, si := range v.([]interface{}) { 166 tmp = append(tmp, si.(string)) 167 } 168 } 169 return tmp, nil 170} 171 172func coerceFloat64Slice(v interface{}) ([]float64, error) { 173 var tmp []float64 174 switch v.(type) { 175 case string: 176 for _, s := range strings.Split(v.(string), ",") { 177 f, err := strconv.ParseFloat(strings.TrimSpace(s), 64) 178 if err != nil { 179 return nil, err 180 } 181 tmp = append(tmp, f) 182 } 183 case []interface{}: 184 for _, fi := range v.([]interface{}) { 185 tmp = append(tmp, fi.(float64)) 186 } 187 case []string: 188 for _, s := range v.([]string) { 189 f, err := strconv.ParseFloat(strings.TrimSpace(s), 64) 190 if err != nil { 191 return nil, err 192 } 193 tmp = append(tmp, f) 194 } 195 } 196 return tmp, nil 197} 198 199func coerceString(v interface{}) (string, error) { 200 return fmt.Sprintf("%s", v), nil 201} 202 203func coerce(v interface{}, opt interface{}) (interface{}, error) { 204 switch opt.(type) { 205 case bool: 206 return coerceBool(v) 207 case int: 208 i, err := coerceInt64(v) 209 if err != nil { 210 return nil, err 211 } 212 return int(i), nil 213 case int16: 214 i, err := coerceInt64(v) 215 if err != nil { 216 return nil, err 217 } 218 return int16(i), nil 219 case uint16: 220 i, err := coerceInt64(v) 221 if err != nil { 222 return nil, err 223 } 224 return uint16(i), nil 225 case int32: 226 i, err := coerceInt64(v) 227 if err != nil { 228 return nil, err 229 } 230 return int32(i), nil 231 case uint32: 232 i, err := coerceInt64(v) 233 if err != nil { 234 return nil, err 235 } 236 return uint32(i), nil 237 case int64: 238 return coerceInt64(v) 239 case uint64: 240 i, err := coerceInt64(v) 241 if err != nil { 242 return nil, err 243 } 244 return uint64(i), nil 245 case float32: 246 i, err := coerceFloat64(v) 247 if err != nil { 248 return nil, err 249 } 250 return float32(i), nil 251 case float64: 252 i, err := coerceFloat64(v) 253 if err != nil { 254 return nil, err 255 } 256 return float64(i), nil 257 case string: 258 return coerceString(v) 259 case time.Duration: 260 return coerceDuration(v) 261 case []string: 262 return coerceStringSlice(v) 263 case []float64: 264 return coerceFloat64Slice(v) 265 } 266 return nil, fmt.Errorf("invalid value type %T", v) 267} 268 269func hasArg(fs *flag.FlagSet, s string) bool { 270 var found bool 271 fs.Visit(func(flag *flag.Flag) { 272 if flag.Name == s { 273 found = true 274 } 275 }) 276 return found 277} 278