1// Copyright (C) 2019 Storj Labs, Inc. 2// See LICENSE for copying information. 3 4package cfgstruct 5 6import ( 7 "fmt" 8 "os" 9 "path/filepath" 10 "reflect" 11 "strconv" 12 "strings" 13 "time" 14 15 "github.com/spf13/cobra" 16 "github.com/spf13/pflag" 17 "go.uber.org/zap" 18 19 "storj.io/private/version" 20) 21 22const ( 23 // AnySource is a source annotation for config values that can come from 24 // a flag or file. 25 AnySource = "any" 26 27 // FlagSource is a source annotation for config values that just come from 28 // flags (i.e. are never persisted to file). 29 FlagSource = "flag" 30 31 // BasicHelpAnnotationName is the name of the annotation used to indicate 32 // a flag should be included in basic usage/help. 33 BasicHelpAnnotationName = "basic-help" 34) 35 36var ( 37 allSources = []string{ 38 AnySource, 39 FlagSource, 40 } 41) 42 43// BindOpt is an option for the Bind method. 44type BindOpt struct { 45 isDev *bool 46 isTest *bool 47 isSetup *bool 48 varfn func(vars map[string]confVar) 49} 50 51// ConfDir sets variables for default options called $CONFDIR. 52func ConfDir(path string) BindOpt { 53 return ConfigVar("CONFDIR", filepath.Clean(os.ExpandEnv(path))) 54} 55 56// IdentityDir sets a variable for the default option called $IDENTITYDIR. 57func IdentityDir(path string) BindOpt { 58 return ConfigVar("IDENTITYDIR", filepath.Clean(os.ExpandEnv(path))) 59} 60 61// ConfigVar sets a variable for the default option called name. 62func ConfigVar(name, val string) BindOpt { 63 name = strings.ToUpper(name) 64 return BindOpt{varfn: func(vars map[string]confVar) { 65 vars[name] = confVar{val: val, nested: false} 66 }} 67} 68 69// SetupMode issues the bind in a mode where it does not ignore fields with the 70// `setup:"true"` tag. 71func SetupMode() BindOpt { 72 setup := true 73 return BindOpt{isSetup: &setup} 74} 75 76// UseDevDefaults forces the bind call to use development defaults unless 77// something else is provided as a subsequent option. 78// Without a specific defaults setting, Bind will default to determining which 79// defaults to use based on version.Build.Release. 80func UseDevDefaults() BindOpt { 81 dev := true 82 test := false 83 return BindOpt{isDev: &dev, isTest: &test} 84} 85 86// UseReleaseDefaults forces the bind call to use release defaults unless 87// something else is provided as a subsequent option. 88// Without a specific defaults setting, Bind will default to determining which 89// defaults to use based on version.Build.Release. 90func UseReleaseDefaults() BindOpt { 91 dev := false 92 test := false 93 return BindOpt{isDev: &dev, isTest: &test} 94} 95 96// UseTestDefaults forces the bind call to use test defaults unless 97// something else is provided as a subsequent option. 98// Without a specific defaults setting, Bind will default to determining which 99// defaults to use based on version.Build.Release. 100func UseTestDefaults() BindOpt { 101 dev := false 102 test := true 103 return BindOpt{isDev: &dev, isTest: &test} 104} 105 106type confVar struct { 107 val string 108 nested bool 109} 110 111// Bind sets flags on a FlagSet that match the configuration struct 112// 'config'. This works by traversing the config struct using the 'reflect' 113// package. 114func Bind(flags FlagSet, config interface{}, opts ...BindOpt) { 115 bind(flags, config, opts...) 116} 117 118func bind(flags FlagSet, config interface{}, opts ...BindOpt) { 119 ptrtype := reflect.TypeOf(config) 120 if ptrtype.Kind() != reflect.Ptr { 121 panic(fmt.Sprintf("invalid config type: %#v. Expecting pointer to struct.", config)) 122 } 123 isDev := !version.Build.Release 124 isTest := false 125 setupCommand := false 126 vars := map[string]confVar{} 127 for _, opt := range opts { 128 if opt.varfn != nil { 129 opt.varfn(vars) 130 } 131 if opt.isDev != nil { 132 isDev = *opt.isDev 133 } 134 if opt.isTest != nil { 135 isTest = *opt.isTest 136 } 137 if opt.isSetup != nil { 138 setupCommand = *opt.isSetup 139 } 140 } 141 142 bindConfig(flags, "", reflect.ValueOf(config).Elem(), vars, setupCommand, false, isDev, isTest) 143} 144 145func bindConfig(flags FlagSet, prefix string, val reflect.Value, vars map[string]confVar, setupCommand, setupStruct bool, isDev, isTest bool) { 146 if val.Kind() != reflect.Struct { 147 panic(fmt.Sprintf("invalid config type: %#v. Expecting struct.", val.Interface())) 148 } 149 typ := val.Type() 150 resolvedVars := make(map[string]string, len(vars)) 151 { 152 structpath := strings.ReplaceAll(prefix, ".", string(filepath.Separator)) 153 for k, v := range vars { 154 if !v.nested { 155 resolvedVars[k] = v.val 156 continue 157 } 158 resolvedVars[k] = filepath.Join(v.val, structpath) 159 } 160 } 161 162 for i := 0; i < typ.NumField(); i++ { 163 field := typ.Field(i) 164 fieldval := val.Field(i) 165 flagname := hyphenate(snakeCase(field.Name)) 166 167 if field.Tag.Get("noprefix") != "true" { 168 flagname = prefix + flagname 169 } 170 171 onlyForSetup := (field.Tag.Get("setup") == "true") || setupStruct 172 // ignore setup params for non setup commands 173 if !setupCommand && onlyForSetup { 174 continue 175 } 176 177 if !fieldval.CanAddr() { 178 panic(fmt.Sprintf("cannot addr field %s in %s", field.Name, typ)) 179 } 180 181 fieldref := fieldval.Addr() 182 if !fieldref.CanInterface() { 183 panic(fmt.Sprintf("cannot get interface of field %s in %s", field.Name, typ)) 184 } 185 186 fieldaddr := fieldref.Interface() 187 if fieldvalue, ok := fieldaddr.(pflag.Value); ok { 188 help := field.Tag.Get("help") 189 def := getDefault(field.Tag, isTest, isDev, flagname) 190 191 if field.Tag.Get("internal") == "true" { 192 if def != "" { 193 panic(fmt.Sprintf("unapplicable default value set for internal flag: %s", flagname)) 194 } 195 continue 196 } 197 198 err := fieldvalue.Set(def) 199 if err != nil { 200 panic(fmt.Sprintf("invalid default value for %s: %#v, %v", flagname, def, err)) 201 } 202 flags.Var(fieldvalue, flagname, help) 203 204 markHidden := false 205 if onlyForSetup { 206 SetBoolAnnotation(flags, flagname, "setup", true) 207 } 208 if field.Tag.Get("user") == "true" { 209 SetBoolAnnotation(flags, flagname, "user", true) 210 } 211 if field.Tag.Get("hidden") == "true" { 212 markHidden = true 213 SetBoolAnnotation(flags, flagname, "hidden", true) 214 } 215 if field.Tag.Get("deprecated") == "true" { 216 markHidden = true 217 SetBoolAnnotation(flags, flagname, "deprecated", true) 218 } 219 if source := field.Tag.Get("source"); source != "" { 220 setSourceAnnotation(flags, flagname, source) 221 } 222 if markHidden { 223 err := flags.MarkHidden(flagname) 224 if err != nil { 225 panic(fmt.Sprintf("mark hidden failed %s: %v", flagname, err)) 226 } 227 } 228 continue 229 } 230 231 switch field.Type.Kind() { 232 case reflect.Struct: 233 if field.Anonymous { 234 bindConfig(flags, prefix, fieldval, vars, setupCommand, onlyForSetup, isDev, isTest) 235 } else { 236 bindConfig(flags, flagname+".", fieldval, vars, setupCommand, onlyForSetup, isDev, isTest) 237 } 238 case reflect.Array: 239 digits := len(fmt.Sprint(fieldval.Len())) 240 for j := 0; j < fieldval.Len(); j++ { 241 padding := strings.Repeat("0", digits-len(fmt.Sprint(j))) 242 bindConfig(flags, fmt.Sprintf("%s.%s%d.", flagname, padding, j), fieldval.Index(j), vars, setupCommand, onlyForSetup, isDev, isTest) 243 } 244 default: 245 help := field.Tag.Get("help") 246 def := getDefault(field.Tag, isTest, isDev, flagname) 247 248 if field.Tag.Get("internal") == "true" { 249 if def != "" { 250 panic(fmt.Sprintf("unapplicable default value set for internal flag: %s", flagname)) 251 } 252 continue 253 } 254 255 def = expand(resolvedVars, def) 256 257 fieldaddr := fieldval.Addr().Interface() 258 check := func(err error) { 259 if err != nil { 260 panic(fmt.Sprintf("invalid default value for %s: %#v", flagname, def)) 261 } 262 } 263 switch field.Type { 264 case reflect.TypeOf(int(0)): 265 val, err := strconv.ParseInt(def, 0, strconv.IntSize) 266 check(err) 267 flags.IntVar(fieldaddr.(*int), flagname, int(val), help) 268 case reflect.TypeOf(int64(0)): 269 val, err := strconv.ParseInt(def, 0, 64) 270 check(err) 271 flags.Int64Var(fieldaddr.(*int64), flagname, val, help) 272 case reflect.TypeOf(uint(0)): 273 val, err := strconv.ParseUint(def, 0, strconv.IntSize) 274 check(err) 275 flags.UintVar(fieldaddr.(*uint), flagname, uint(val), help) 276 case reflect.TypeOf(uint64(0)): 277 val, err := strconv.ParseUint(def, 0, 64) 278 check(err) 279 flags.Uint64Var(fieldaddr.(*uint64), flagname, val, help) 280 case reflect.TypeOf(time.Duration(0)): 281 val, err := time.ParseDuration(def) 282 check(err) 283 flags.DurationVar(fieldaddr.(*time.Duration), flagname, val, help) 284 case reflect.TypeOf(float64(0)): 285 val, err := strconv.ParseFloat(def, 64) 286 check(err) 287 flags.Float64Var(fieldaddr.(*float64), flagname, val, help) 288 case reflect.TypeOf(string("")): 289 if field.Tag.Get("path") == "true" { 290 // NB: conventionally unix path separators are used in default values 291 def = filepath.FromSlash(def) 292 } 293 flags.StringVar(fieldaddr.(*string), flagname, def, help) 294 case reflect.TypeOf(bool(false)): 295 val, err := strconv.ParseBool(def) 296 check(err) 297 flags.BoolVar(fieldaddr.(*bool), flagname, val, help) 298 case reflect.TypeOf([]string(nil)): 299 flags.StringArrayVar(fieldaddr.(*[]string), flagname, nil, help) 300 default: 301 panic(fmt.Sprintf("invalid field type: %s", field.Type.String())) 302 } 303 if onlyForSetup { 304 SetBoolAnnotation(flags, flagname, "setup", true) 305 } 306 if field.Tag.Get("user") == "true" { 307 SetBoolAnnotation(flags, flagname, "user", true) 308 } 309 if field.Tag.Get(BasicHelpAnnotationName) == "true" { 310 SetBoolAnnotation(flags, flagname, BasicHelpAnnotationName, true) 311 } 312 313 markHidden := false 314 if field.Tag.Get("hidden") == "true" { 315 markHidden = true 316 SetBoolAnnotation(flags, flagname, "hidden", true) 317 } 318 if field.Tag.Get("deprecated") == "true" { 319 markHidden = true 320 SetBoolAnnotation(flags, flagname, "deprecated", true) 321 } 322 if source := field.Tag.Get("source"); source != "" { 323 setSourceAnnotation(flags, flagname, source) 324 } 325 if markHidden { 326 err := flags.MarkHidden(flagname) 327 if err != nil { 328 panic(fmt.Sprintf("mark hidden failed %s: %v", flagname, err)) 329 } 330 } 331 } 332 } 333} 334 335func getDefault(tag reflect.StructTag, isTest, isDev bool, flagname string) string { 336 var order []string 337 var opposites []string 338 if isTest { 339 order = []string{"testDefault", "devDefault", "default"} 340 opposites = []string{"releaseDefault"} 341 } else if isDev { 342 order = []string{"devDefault", "default"} 343 opposites = []string{"releaseDefault", "testDefault"} 344 } else { 345 order = []string{"releaseDefault", "default"} 346 opposites = []string{"devDefault", "testDefault"} 347 } 348 349 for _, name := range order { 350 if val, ok := tag.Lookup(name); ok { 351 return val 352 } 353 } 354 355 for _, name := range opposites { 356 if _, ok := tag.Lookup(name); ok { 357 panic(fmt.Sprintf("%q missing but %q defined for %v", order[0], name, flagname)) 358 } 359 } 360 361 return "" 362} 363 364func setSourceAnnotation(flagset interface{}, name, source string) { 365 switch source { 366 case AnySource: 367 case FlagSource: 368 default: 369 panic(fmt.Sprintf("invalid source annotation %q for %s: must be one of %q", source, name, allSources)) 370 } 371 372 setStringAnnotation(flagset, name, "source", source) 373} 374 375func setStringAnnotation(flagset interface{}, name, key, value string) { 376 flags, ok := flagset.(*pflag.FlagSet) 377 if !ok { 378 return 379 } 380 381 err := flags.SetAnnotation(name, key, []string{value}) 382 if err != nil { 383 panic(fmt.Sprintf("unable to set %s annotation for %s: %v", key, name, err)) 384 } 385} 386 387// SetBoolAnnotation sets an annotation (if it can) on flagset with a value of []string{"true|false"}. 388func SetBoolAnnotation(flagset interface{}, name, key string, value bool) { 389 flags, ok := flagset.(*pflag.FlagSet) 390 if !ok { 391 return 392 } 393 394 err := flags.SetAnnotation(name, key, []string{strconv.FormatBool(value)}) 395 if err != nil { 396 panic(fmt.Sprintf("unable to set %s annotation for %s: %v", key, name, err)) 397 } 398} 399 400func expand(vars map[string]string, val string) string { 401 return os.Expand(val, func(key string) string { return vars[key] }) 402} 403 404// FindConfigDirParam returns '--config-dir' param from os.Args (if exists). 405func FindConfigDirParam() string { 406 return FindFlagEarly("config-dir") 407} 408 409// FindIdentityDirParam returns '--identity-dir' param from os.Args (if exists). 410func FindIdentityDirParam() string { 411 return FindFlagEarly("identity-dir") 412} 413 414// FindDefaultsParam returns '--defaults' param from os.Args (if it exists). 415func FindDefaultsParam() string { 416 return FindFlagEarly("defaults") 417} 418 419// FindFlagEarly retrieves the value of a flag before `flag.Parse` has been called. 420func FindFlagEarly(flagName string) string { 421 // workaround to have early access to 'dir' param 422 for i, arg := range os.Args { 423 if strings.HasPrefix(arg, fmt.Sprintf("--%s=", flagName)) { 424 return strings.TrimPrefix(arg, fmt.Sprintf("--%s=", flagName)) 425 } else if arg == fmt.Sprintf("--%s", flagName) && i < len(os.Args)-1 { 426 return os.Args[i+1] 427 } 428 } 429 return "" 430} 431 432// SetupFlag sets up flags that are needed before `flag.Parse` has been called. 433func SetupFlag(log *zap.Logger, cmd *cobra.Command, dest *string, name, value, usage string) { 434 if foundValue := FindFlagEarly(name); foundValue != "" { 435 value = foundValue 436 } 437 cmd.PersistentFlags().StringVar(dest, name, value, usage) 438 if cmd.PersistentFlags().SetAnnotation(name, "setup", []string{"true"}) != nil { 439 log.Error("Failed to set 'setup' annotation", zap.String("Flag", name)) 440 } 441} 442 443// DefaultsType returns the type of defaults (release/dev) this binary should use. 444func DefaultsType() string { 445 // define a flag so that the flag parsing system will be happy. 446 defaults := strings.ToLower(FindDefaultsParam()) 447 if defaults != "" { 448 return defaults 449 } 450 if version.Build.Release { 451 return "release" 452 } 453 return "dev" 454} 455 456// DefaultsFlag sets up the defaults=dev/release flag options, which is needed 457// before `flag.Parse` has been called. 458func DefaultsFlag(cmd *cobra.Command) BindOpt { 459 // define a flag so that the flag parsing system will be happy. 460 defaults := DefaultsType() 461 462 // we're actually going to ignore this flag entirely and parse the commandline 463 // arguments early instead 464 _ = cmd.PersistentFlags().String("defaults", defaults, 465 "determines which set of configuration defaults to use. can either be 'dev' or 'release'") 466 setSourceAnnotation(cmd.PersistentFlags(), "defaults", FlagSource) 467 468 switch defaults { 469 case "dev": 470 return UseDevDefaults() 471 case "release": 472 return UseReleaseDefaults() 473 case "test": 474 return UseTestDefaults() 475 default: 476 panic(fmt.Sprintf("unsupported defaults value %q", FindDefaultsParam())) 477 } 478} 479