1package options
2
import (
	"errors"
	"fmt"
	"go/token"
	"io/ioutil"
	"reflect"
	"strings"

	"github.com/ghodss/yaml"
	"github.com/mitchellh/mapstructure"
	"github.com/spf13/pflag"
	"github.com/spf13/viper"
)
15
16// Load reads in the config file at the path given, then merges in environment
17// variables (prefixed with `OAUTH2_PROXY`) and finally merges in flags from the flagSet.
18// If a config value is unset and the flag has a non-zero value default, this default will be used.
19// Eg. A field defined:
20//    FooBar `cfg:"foo_bar" flag:"foo-bar"`
21// Can be set in the config file as `foo_bar="baz"`, in the environment as `OAUTH2_PROXY_FOO_BAR=baz`,
22// or via the command line flag `--foo-bar=baz`.
23func Load(configFileName string, flagSet *pflag.FlagSet, into interface{}) error {
24	v := viper.New()
25	v.SetConfigFile(configFileName)
26	v.SetConfigType("toml") // Config is in toml format
27	v.SetEnvPrefix("OAUTH2_PROXY")
28	v.AutomaticEnv()
29	v.SetTypeByDefaultValue(true)
30
31	if configFileName != "" {
32		err := v.ReadInConfig()
33		if err != nil {
34			return fmt.Errorf("unable to load config file: %w", err)
35		}
36	}
37
38	err := registerFlags(v, "", flagSet, into)
39	if err != nil {
40		// This should only happen if there is a programming error
41		return fmt.Errorf("unable to register flags: %w", err)
42	}
43
44	// UnmarhsalExact will return an error if the config includes options that are
45	// not mapped to felds of the into struct
46	err = v.UnmarshalExact(into, decodeFromCfgTag)
47	if err != nil {
48		return fmt.Errorf("error unmarshalling config: %w", err)
49	}
50
51	return nil
52}
53
54// registerFlags uses `cfg` and `flag` tags to associate flags in the flagSet
55// to the fields in the options interface provided.
56// Each exported field in the options must have a `cfg` tag otherwise an error will occur.
57// - For fields, set `cfg` and `flag` so that `flag` is the name of the flag associated to this config option
58// - For exported fields that are not user facing, set the `cfg` to `,internal`
59// - For structs containing user facing fields, set the `cfg` to `,squash`
60func registerFlags(v *viper.Viper, prefix string, flagSet *pflag.FlagSet, options interface{}) error {
61	val := reflect.ValueOf(options)
62	var typ reflect.Type
63	if val.Kind() == reflect.Ptr {
64		typ = val.Elem().Type()
65	} else {
66		typ = val.Type()
67	}
68
69	for i := 0; i < typ.NumField(); i++ {
70		// pull out the struct tags:
71		//    flag - the name of the command line flag
72		//    cfg - the name of the config file option
73		field := typ.Field(i)
74		fieldV := reflect.Indirect(val).Field(i)
75		fieldName := strings.Join([]string{prefix, field.Name}, ".")
76
77		cfgName := field.Tag.Get("cfg")
78		if cfgName == ",internal" {
79			// Public but internal types that should not be exposed to users, skip them
80			continue
81		}
82
83		if isUnexported(field.Name) {
84			// Unexported fields cannot be set by a user, so won't have tags or flags, skip them
85			continue
86		}
87
88		if field.Type.Kind() == reflect.Struct {
89			if cfgName != ",squash" {
90				return fmt.Errorf("field %q does not have required cfg tag: `,squash`", fieldName)
91			}
92			err := registerFlags(v, fieldName, flagSet, fieldV.Interface())
93			if err != nil {
94				return err
95			}
96			continue
97		}
98
99		flagName := field.Tag.Get("flag")
100		if flagName == "" || cfgName == "" {
101			return fmt.Errorf("field %q does not have required tags (cfg, flag)", fieldName)
102		}
103
104		if flagSet == nil {
105			return fmt.Errorf("flagset cannot be nil")
106		}
107
108		f := flagSet.Lookup(flagName)
109		if f == nil {
110			return fmt.Errorf("field %q does not have a registered flag", flagName)
111		}
112		err := v.BindPFlag(cfgName, f)
113		if err != nil {
114			return fmt.Errorf("error binding flag for field %q: %w", fieldName, err)
115		}
116	}
117
118	return nil
119}
120
121// decodeFromCfgTag sets the Viper decoder to read the names from the `cfg` tag
122// on each struct entry.
123func decodeFromCfgTag(c *mapstructure.DecoderConfig) {
124	c.TagName = "cfg"
125}
126
// isUnexported reports whether a struct field name is unexported under Go's
// export rules, i.e. whether its first Unicode character is not an
// upper-case letter. The first rune is decoded as UTF-8, so names beginning
// with non-ASCII letters (e.g. "über" vs "Über") are classified correctly.
//
// It panics on an empty name; struct fields always have a name, so an empty
// one indicates a programming error.
func isUnexported(name string) bool {
	if len(name) == 0 {
		// This should never happen
		panic("field name has len 0")
	}

	// token.IsExported implements the language's export rule exactly, unlike
	// inspecting only the first byte, which breaks on multi-byte letters.
	return !token.IsExported(name)
}
138
139// LoadYAML will load a YAML based configuration file into the options interface provided.
140func LoadYAML(configFileName string, into interface{}) error {
141	v := viper.New()
142	v.SetConfigFile(configFileName)
143	v.SetConfigType("yaml")
144	v.SetTypeByDefaultValue(true)
145
146	if configFileName == "" {
147		return errors.New("no configuration file provided")
148	}
149
150	data, err := ioutil.ReadFile(configFileName)
151	if err != nil {
152		return fmt.Errorf("unable to load config file: %w", err)
153	}
154
155	// UnmarshalStrict will return an error if the config includes options that are
156	// not mapped to felds of the into struct
157	if err := yaml.UnmarshalStrict(data, into, yaml.DisallowUnknownFields); err != nil {
158		return fmt.Errorf("error unmarshalling config: %w", err)
159	}
160
161	return nil
162}
163