1package options 2 3import ( 4 "errors" 5 "fmt" 6 "io/ioutil" 7 "reflect" 8 "strings" 9 10 "github.com/ghodss/yaml" 11 "github.com/mitchellh/mapstructure" 12 "github.com/spf13/pflag" 13 "github.com/spf13/viper" 14) 15 16// Load reads in the config file at the path given, then merges in environment 17// variables (prefixed with `OAUTH2_PROXY`) and finally merges in flags from the flagSet. 18// If a config value is unset and the flag has a non-zero value default, this default will be used. 19// Eg. A field defined: 20// FooBar `cfg:"foo_bar" flag:"foo-bar"` 21// Can be set in the config file as `foo_bar="baz"`, in the environment as `OAUTH2_PROXY_FOO_BAR=baz`, 22// or via the command line flag `--foo-bar=baz`. 23func Load(configFileName string, flagSet *pflag.FlagSet, into interface{}) error { 24 v := viper.New() 25 v.SetConfigFile(configFileName) 26 v.SetConfigType("toml") // Config is in toml format 27 v.SetEnvPrefix("OAUTH2_PROXY") 28 v.AutomaticEnv() 29 v.SetTypeByDefaultValue(true) 30 31 if configFileName != "" { 32 err := v.ReadInConfig() 33 if err != nil { 34 return fmt.Errorf("unable to load config file: %w", err) 35 } 36 } 37 38 err := registerFlags(v, "", flagSet, into) 39 if err != nil { 40 // This should only happen if there is a programming error 41 return fmt.Errorf("unable to register flags: %w", err) 42 } 43 44 // UnmarhsalExact will return an error if the config includes options that are 45 // not mapped to felds of the into struct 46 err = v.UnmarshalExact(into, decodeFromCfgTag) 47 if err != nil { 48 return fmt.Errorf("error unmarshalling config: %w", err) 49 } 50 51 return nil 52} 53 54// registerFlags uses `cfg` and `flag` tags to associate flags in the flagSet 55// to the fields in the options interface provided. 56// Each exported field in the options must have a `cfg` tag otherwise an error will occur. 57// - For fields, set `cfg` and `flag` so that `flag` is the name of the flag associated to this config option 58// - For exported fields that are not user facing, set the `cfg` to `,internal` 59// - For structs containing user facing fields, set the `cfg` to `,squash` 60func registerFlags(v *viper.Viper, prefix string, flagSet *pflag.FlagSet, options interface{}) error { 61 val := reflect.ValueOf(options) 62 var typ reflect.Type 63 if val.Kind() == reflect.Ptr { 64 typ = val.Elem().Type() 65 } else { 66 typ = val.Type() 67 } 68 69 for i := 0; i < typ.NumField(); i++ { 70 // pull out the struct tags: 71 // flag - the name of the command line flag 72 // cfg - the name of the config file option 73 field := typ.Field(i) 74 fieldV := reflect.Indirect(val).Field(i) 75 fieldName := strings.Join([]string{prefix, field.Name}, ".") 76 77 cfgName := field.Tag.Get("cfg") 78 if cfgName == ",internal" { 79 // Public but internal types that should not be exposed to users, skip them 80 continue 81 } 82 83 if isUnexported(field.Name) { 84 // Unexported fields cannot be set by a user, so won't have tags or flags, skip them 85 continue 86 } 87 88 if field.Type.Kind() == reflect.Struct { 89 if cfgName != ",squash" { 90 return fmt.Errorf("field %q does not have required cfg tag: `,squash`", fieldName) 91 } 92 err := registerFlags(v, fieldName, flagSet, fieldV.Interface()) 93 if err != nil { 94 return err 95 } 96 continue 97 } 98 99 flagName := field.Tag.Get("flag") 100 if flagName == "" || cfgName == "" { 101 return fmt.Errorf("field %q does not have required tags (cfg, flag)", fieldName) 102 } 103 104 if flagSet == nil { 105 return fmt.Errorf("flagset cannot be nil") 106 } 107 108 f := flagSet.Lookup(flagName) 109 if f == nil { 110 return fmt.Errorf("field %q does not have a registered flag", flagName) 111 } 112 err := v.BindPFlag(cfgName, f) 113 if err != nil { 114 return fmt.Errorf("error binding flag for field %q: %w", fieldName, err) 115 } 116 } 117 118 return nil 119} 120 121// decodeFromCfgTag sets the Viper decoder to read the names from the `cfg` tag 122// on each struct entry. 123func decodeFromCfgTag(c *mapstructure.DecoderConfig) { 124 c.TagName = "cfg" 125} 126 127// isUnexported checks if a field name starts with a lowercase letter and therefore 128// if it is unexported. 129func isUnexported(name string) bool { 130 if len(name) == 0 { 131 // This should never happen 132 panic("field name has len 0") 133 } 134 135 first := string(name[0]) 136 return first == strings.ToLower(first) 137} 138 139// LoadYAML will load a YAML based configuration file into the options interface provided. 140func LoadYAML(configFileName string, into interface{}) error { 141 v := viper.New() 142 v.SetConfigFile(configFileName) 143 v.SetConfigType("yaml") 144 v.SetTypeByDefaultValue(true) 145 146 if configFileName == "" { 147 return errors.New("no configuration file provided") 148 } 149 150 data, err := ioutil.ReadFile(configFileName) 151 if err != nil { 152 return fmt.Errorf("unable to load config file: %w", err) 153 } 154 155 // UnmarshalStrict will return an error if the config includes options that are 156 // not mapped to felds of the into struct 157 if err := yaml.UnmarshalStrict(data, into, yaml.DisallowUnknownFields); err != nil { 158 return fmt.Errorf("error unmarshalling config: %w", err) 159 } 160 161 return nil 162} 163