1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package cfgstruct
5
6import (
7	"fmt"
8	"os"
9	"path/filepath"
10	"reflect"
11	"strconv"
12	"strings"
13	"time"
14
15	"github.com/spf13/cobra"
16	"github.com/spf13/pflag"
17	"go.uber.org/zap"
18
19	"storj.io/private/version"
20)
21
// Values accepted by the `source` struct tag, which records where a config
// value may be supplied from.
const (
	// AnySource is a source annotation for config values that can come from
	// a flag or file.
	AnySource = "any"
	// FlagSource is a source annotation for config values that just come from
	// flags (i.e. are never persisted to file).
	FlagSource = "flag"
)

// BasicHelpAnnotationName is the name of the annotation used to indicate
// a flag should be included in basic usage/help.
const BasicHelpAnnotationName = "basic-help"

// allSources lists every valid `source` tag value; it is reported in the
// panic message when an unknown source is rejected.
var allSources = []string{AnySource, FlagSource}
42
// BindOpt is an option for the Bind method.
//
// Construct one with ConfDir, IdentityDir, ConfigVar, SetupMode, or one of the
// Use*Defaults helpers. Every field is optional: a nil pointer (or nil varfn)
// means the option leaves that setting untouched. This lets Bind apply several
// options in order, with later ones overriding earlier ones.
//
//   - isDev/isTest select which struct-tag defaults are used (see getDefault).
//   - isSetup enables binding of fields tagged `setup:"true"`.
//   - varfn records $NAME variables used when expanding default values.
type BindOpt struct {
	isDev   *bool
	isTest  *bool
	isSetup *bool
	varfn   func(vars map[string]confVar)
}
50
51// ConfDir sets variables for default options called $CONFDIR.
52func ConfDir(path string) BindOpt {
53	return ConfigVar("CONFDIR", filepath.Clean(os.ExpandEnv(path)))
54}
55
56// IdentityDir sets a variable for the default option called $IDENTITYDIR.
57func IdentityDir(path string) BindOpt {
58	return ConfigVar("IDENTITYDIR", filepath.Clean(os.ExpandEnv(path)))
59}
60
61// ConfigVar sets a variable for the default option called name.
62func ConfigVar(name, val string) BindOpt {
63	name = strings.ToUpper(name)
64	return BindOpt{varfn: func(vars map[string]confVar) {
65		vars[name] = confVar{val: val, nested: false}
66	}}
67}
68
69// SetupMode issues the bind in a mode where it does not ignore fields with the
70// `setup:"true"` tag.
71func SetupMode() BindOpt {
72	setup := true
73	return BindOpt{isSetup: &setup}
74}
75
76// UseDevDefaults forces the bind call to use development defaults unless
77// something else is provided as a subsequent option.
78// Without a specific defaults setting, Bind will default to determining which
79// defaults to use based on version.Build.Release.
80func UseDevDefaults() BindOpt {
81	dev := true
82	test := false
83	return BindOpt{isDev: &dev, isTest: &test}
84}
85
86// UseReleaseDefaults forces the bind call to use release defaults unless
87// something else is provided as a subsequent option.
88// Without a specific defaults setting, Bind will default to determining which
89// defaults to use based on version.Build.Release.
90func UseReleaseDefaults() BindOpt {
91	dev := false
92	test := false
93	return BindOpt{isDev: &dev, isTest: &test}
94}
95
96// UseTestDefaults forces the bind call to use test defaults unless
97// something else is provided as a subsequent option.
98// Without a specific defaults setting, Bind will default to determining which
99// defaults to use based on version.Build.Release.
100func UseTestDefaults() BindOpt {
101	dev := false
102	test := true
103	return BindOpt{isDev: &dev, isTest: &test}
104}
105
// confVar is a variable available for $NAME expansion in default values.
//
// val is the variable's value. When nested is true, bindConfig joins val with
// the current struct path (the flag prefix with "." replaced by the OS path
// separator), so the value differs per struct level. Otherwise val is used
// as-is.
type confVar struct {
	val    string
	nested bool
}
110
111// Bind sets flags on a FlagSet that match the configuration struct
112// 'config'. This works by traversing the config struct using the 'reflect'
113// package.
114func Bind(flags FlagSet, config interface{}, opts ...BindOpt) {
115	bind(flags, config, opts...)
116}
117
118func bind(flags FlagSet, config interface{}, opts ...BindOpt) {
119	ptrtype := reflect.TypeOf(config)
120	if ptrtype.Kind() != reflect.Ptr {
121		panic(fmt.Sprintf("invalid config type: %#v. Expecting pointer to struct.", config))
122	}
123	isDev := !version.Build.Release
124	isTest := false
125	setupCommand := false
126	vars := map[string]confVar{}
127	for _, opt := range opts {
128		if opt.varfn != nil {
129			opt.varfn(vars)
130		}
131		if opt.isDev != nil {
132			isDev = *opt.isDev
133		}
134		if opt.isTest != nil {
135			isTest = *opt.isTest
136		}
137		if opt.isSetup != nil {
138			setupCommand = *opt.isSetup
139		}
140	}
141
142	bindConfig(flags, "", reflect.ValueOf(config).Elem(), vars, setupCommand, false, isDev, isTest)
143}
144
145func bindConfig(flags FlagSet, prefix string, val reflect.Value, vars map[string]confVar, setupCommand, setupStruct bool, isDev, isTest bool) {
146	if val.Kind() != reflect.Struct {
147		panic(fmt.Sprintf("invalid config type: %#v. Expecting struct.", val.Interface()))
148	}
149	typ := val.Type()
150	resolvedVars := make(map[string]string, len(vars))
151	{
152		structpath := strings.ReplaceAll(prefix, ".", string(filepath.Separator))
153		for k, v := range vars {
154			if !v.nested {
155				resolvedVars[k] = v.val
156				continue
157			}
158			resolvedVars[k] = filepath.Join(v.val, structpath)
159		}
160	}
161
162	for i := 0; i < typ.NumField(); i++ {
163		field := typ.Field(i)
164		fieldval := val.Field(i)
165		flagname := hyphenate(snakeCase(field.Name))
166
167		if field.Tag.Get("noprefix") != "true" {
168			flagname = prefix + flagname
169		}
170
171		onlyForSetup := (field.Tag.Get("setup") == "true") || setupStruct
172		// ignore setup params for non setup commands
173		if !setupCommand && onlyForSetup {
174			continue
175		}
176
177		if !fieldval.CanAddr() {
178			panic(fmt.Sprintf("cannot addr field %s in %s", field.Name, typ))
179		}
180
181		fieldref := fieldval.Addr()
182		if !fieldref.CanInterface() {
183			panic(fmt.Sprintf("cannot get interface of field %s in %s", field.Name, typ))
184		}
185
186		fieldaddr := fieldref.Interface()
187		if fieldvalue, ok := fieldaddr.(pflag.Value); ok {
188			help := field.Tag.Get("help")
189			def := getDefault(field.Tag, isTest, isDev, flagname)
190
191			if field.Tag.Get("internal") == "true" {
192				if def != "" {
193					panic(fmt.Sprintf("unapplicable default value set for internal flag: %s", flagname))
194				}
195				continue
196			}
197
198			err := fieldvalue.Set(def)
199			if err != nil {
200				panic(fmt.Sprintf("invalid default value for %s: %#v, %v", flagname, def, err))
201			}
202			flags.Var(fieldvalue, flagname, help)
203
204			markHidden := false
205			if onlyForSetup {
206				SetBoolAnnotation(flags, flagname, "setup", true)
207			}
208			if field.Tag.Get("user") == "true" {
209				SetBoolAnnotation(flags, flagname, "user", true)
210			}
211			if field.Tag.Get("hidden") == "true" {
212				markHidden = true
213				SetBoolAnnotation(flags, flagname, "hidden", true)
214			}
215			if field.Tag.Get("deprecated") == "true" {
216				markHidden = true
217				SetBoolAnnotation(flags, flagname, "deprecated", true)
218			}
219			if source := field.Tag.Get("source"); source != "" {
220				setSourceAnnotation(flags, flagname, source)
221			}
222			if markHidden {
223				err := flags.MarkHidden(flagname)
224				if err != nil {
225					panic(fmt.Sprintf("mark hidden failed %s: %v", flagname, err))
226				}
227			}
228			continue
229		}
230
231		switch field.Type.Kind() {
232		case reflect.Struct:
233			if field.Anonymous {
234				bindConfig(flags, prefix, fieldval, vars, setupCommand, onlyForSetup, isDev, isTest)
235			} else {
236				bindConfig(flags, flagname+".", fieldval, vars, setupCommand, onlyForSetup, isDev, isTest)
237			}
238		case reflect.Array:
239			digits := len(fmt.Sprint(fieldval.Len()))
240			for j := 0; j < fieldval.Len(); j++ {
241				padding := strings.Repeat("0", digits-len(fmt.Sprint(j)))
242				bindConfig(flags, fmt.Sprintf("%s.%s%d.", flagname, padding, j), fieldval.Index(j), vars, setupCommand, onlyForSetup, isDev, isTest)
243			}
244		default:
245			help := field.Tag.Get("help")
246			def := getDefault(field.Tag, isTest, isDev, flagname)
247
248			if field.Tag.Get("internal") == "true" {
249				if def != "" {
250					panic(fmt.Sprintf("unapplicable default value set for internal flag: %s", flagname))
251				}
252				continue
253			}
254
255			def = expand(resolvedVars, def)
256
257			fieldaddr := fieldval.Addr().Interface()
258			check := func(err error) {
259				if err != nil {
260					panic(fmt.Sprintf("invalid default value for %s: %#v", flagname, def))
261				}
262			}
263			switch field.Type {
264			case reflect.TypeOf(int(0)):
265				val, err := strconv.ParseInt(def, 0, strconv.IntSize)
266				check(err)
267				flags.IntVar(fieldaddr.(*int), flagname, int(val), help)
268			case reflect.TypeOf(int64(0)):
269				val, err := strconv.ParseInt(def, 0, 64)
270				check(err)
271				flags.Int64Var(fieldaddr.(*int64), flagname, val, help)
272			case reflect.TypeOf(uint(0)):
273				val, err := strconv.ParseUint(def, 0, strconv.IntSize)
274				check(err)
275				flags.UintVar(fieldaddr.(*uint), flagname, uint(val), help)
276			case reflect.TypeOf(uint64(0)):
277				val, err := strconv.ParseUint(def, 0, 64)
278				check(err)
279				flags.Uint64Var(fieldaddr.(*uint64), flagname, val, help)
280			case reflect.TypeOf(time.Duration(0)):
281				val, err := time.ParseDuration(def)
282				check(err)
283				flags.DurationVar(fieldaddr.(*time.Duration), flagname, val, help)
284			case reflect.TypeOf(float64(0)):
285				val, err := strconv.ParseFloat(def, 64)
286				check(err)
287				flags.Float64Var(fieldaddr.(*float64), flagname, val, help)
288			case reflect.TypeOf(string("")):
289				if field.Tag.Get("path") == "true" {
290					// NB: conventionally unix path separators are used in default values
291					def = filepath.FromSlash(def)
292				}
293				flags.StringVar(fieldaddr.(*string), flagname, def, help)
294			case reflect.TypeOf(bool(false)):
295				val, err := strconv.ParseBool(def)
296				check(err)
297				flags.BoolVar(fieldaddr.(*bool), flagname, val, help)
298			case reflect.TypeOf([]string(nil)):
299				flags.StringArrayVar(fieldaddr.(*[]string), flagname, nil, help)
300			default:
301				panic(fmt.Sprintf("invalid field type: %s", field.Type.String()))
302			}
303			if onlyForSetup {
304				SetBoolAnnotation(flags, flagname, "setup", true)
305			}
306			if field.Tag.Get("user") == "true" {
307				SetBoolAnnotation(flags, flagname, "user", true)
308			}
309			if field.Tag.Get(BasicHelpAnnotationName) == "true" {
310				SetBoolAnnotation(flags, flagname, BasicHelpAnnotationName, true)
311			}
312
313			markHidden := false
314			if field.Tag.Get("hidden") == "true" {
315				markHidden = true
316				SetBoolAnnotation(flags, flagname, "hidden", true)
317			}
318			if field.Tag.Get("deprecated") == "true" {
319				markHidden = true
320				SetBoolAnnotation(flags, flagname, "deprecated", true)
321			}
322			if source := field.Tag.Get("source"); source != "" {
323				setSourceAnnotation(flags, flagname, source)
324			}
325			if markHidden {
326				err := flags.MarkHidden(flagname)
327				if err != nil {
328					panic(fmt.Sprintf("mark hidden failed %s: %v", flagname, err))
329				}
330			}
331		}
332	}
333}
334
// getDefault picks the default value for flagname from a field's struct tag.
//
// Tags are consulted in priority order for the selected mode (isTest takes
// precedence over isDev):
//   - test:    testDefault, devDefault, default
//   - dev:     devDefault, default
//   - release: releaseDefault, default
//
// If none of those tags is present, it returns "". The exception is when a
// tag belonging to another mode is set: a mode-specific default without a
// match for the current mode makes it panic.
func getDefault(tag reflect.StructTag, isTest, isDev bool, flagname string) string {
	var preferred, conflicting []string
	switch {
	case isTest:
		preferred = []string{"testDefault", "devDefault", "default"}
		conflicting = []string{"releaseDefault"}
	case isDev:
		preferred = []string{"devDefault", "default"}
		conflicting = []string{"releaseDefault", "testDefault"}
	default:
		preferred = []string{"releaseDefault", "default"}
		conflicting = []string{"devDefault", "testDefault"}
	}

	for _, key := range preferred {
		if v, found := tag.Lookup(key); found {
			return v
		}
	}

	for _, key := range conflicting {
		if _, found := tag.Lookup(key); !found {
			continue
		}
		panic(fmt.Sprintf("%q missing but %q defined for %v", preferred[0], key, flagname))
	}

	return ""
}
363
364func setSourceAnnotation(flagset interface{}, name, source string) {
365	switch source {
366	case AnySource:
367	case FlagSource:
368	default:
369		panic(fmt.Sprintf("invalid source annotation %q for %s: must be one of %q", source, name, allSources))
370	}
371
372	setStringAnnotation(flagset, name, "source", source)
373}
374
375func setStringAnnotation(flagset interface{}, name, key, value string) {
376	flags, ok := flagset.(*pflag.FlagSet)
377	if !ok {
378		return
379	}
380
381	err := flags.SetAnnotation(name, key, []string{value})
382	if err != nil {
383		panic(fmt.Sprintf("unable to set %s annotation for %s: %v", key, name, err))
384	}
385}
386
387// SetBoolAnnotation sets an annotation (if it can) on flagset with a value of []string{"true|false"}.
388func SetBoolAnnotation(flagset interface{}, name, key string, value bool) {
389	flags, ok := flagset.(*pflag.FlagSet)
390	if !ok {
391		return
392	}
393
394	err := flags.SetAnnotation(name, key, []string{strconv.FormatBool(value)})
395	if err != nil {
396		panic(fmt.Sprintf("unable to set %s annotation for %s: %v", key, name, err))
397	}
398}
399
// expand substitutes $NAME and ${NAME} references in val with the matching
// entry from vars. Names missing from vars (including when vars is nil)
// expand to the empty string.
func expand(vars map[string]string, val string) string {
	lookup := func(name string) string {
		return vars[name]
	}
	return os.Expand(val, lookup)
}
403
// FindConfigDirParam returns '--config-dir' param from os.Args (if exists).
// An empty result means the flag is absent or was given an empty value.
func FindConfigDirParam() string {
	return FindFlagEarly("config-dir")
}

// FindIdentityDirParam returns '--identity-dir' param from os.Args (if exists).
// An empty result means the flag is absent or was given an empty value.
func FindIdentityDirParam() string {
	return FindFlagEarly("identity-dir")
}

// FindDefaultsParam returns '--defaults' param from os.Args (if it exists).
// An empty result means the flag is absent or was given an empty value.
func FindDefaultsParam() string {
	return FindFlagEarly("defaults")
}

// FindFlagEarly retrieves the value of a flag before `flag.Parse` has been called.
//
// It scans os.Args for the "--flagName=value" and "--flagName value" forms.
// It returns the value of the last occurrence, as pflag does for string flags
// where a later occurrence overrides an earlier one. Scanning stops at a bare
// "--", after which arguments are positional. It returns "" when the flag is
// absent.
func FindFlagEarly(flagName string) string {
	// workaround to have early access to flag values such as the config dir
	return findFlagInArgs(os.Args, flagName)
}

// findFlagInArgs implements FindFlagEarly over an explicit argument list.
func findFlagInArgs(args []string, flagName string) string {
	bare := "--" + flagName
	withValue := bare + "="

	value := ""
	for i := 0; i < len(args); i++ {
		arg := args[i]
		switch {
		case arg == "--":
			// The flag parser treats everything after "--" as positional.
			return value
		case strings.HasPrefix(arg, withValue):
			value = strings.TrimPrefix(arg, withValue)
		case arg == bare && i+1 < len(args):
			// The following argument is this flag's value. Skip it so it is
			// not itself mistaken for a flag or for the "--" terminator.
			i++
			value = args[i]
		}
	}
	return value
}
431
432// SetupFlag sets up flags that are needed before `flag.Parse` has been called.
433func SetupFlag(log *zap.Logger, cmd *cobra.Command, dest *string, name, value, usage string) {
434	if foundValue := FindFlagEarly(name); foundValue != "" {
435		value = foundValue
436	}
437	cmd.PersistentFlags().StringVar(dest, name, value, usage)
438	if cmd.PersistentFlags().SetAnnotation(name, "setup", []string{"true"}) != nil {
439		log.Error("Failed to set 'setup' annotation", zap.String("Flag", name))
440	}
441}
442
443// DefaultsType returns the type of defaults (release/dev) this binary should use.
444func DefaultsType() string {
445	// define a flag so that the flag parsing system will be happy.
446	defaults := strings.ToLower(FindDefaultsParam())
447	if defaults != "" {
448		return defaults
449	}
450	if version.Build.Release {
451		return "release"
452	}
453	return "dev"
454}
455
456// DefaultsFlag sets up the defaults=dev/release flag options, which is needed
457// before `flag.Parse` has been called.
458func DefaultsFlag(cmd *cobra.Command) BindOpt {
459	// define a flag so that the flag parsing system will be happy.
460	defaults := DefaultsType()
461
462	// we're actually going to ignore this flag entirely and parse the commandline
463	// arguments early instead
464	_ = cmd.PersistentFlags().String("defaults", defaults,
465		"determines which set of configuration defaults to use. can either be 'dev' or 'release'")
466	setSourceAnnotation(cmd.PersistentFlags(), "defaults", FlagSource)
467
468	switch defaults {
469	case "dev":
470		return UseDevDefaults()
471	case "release":
472		return UseReleaseDefaults()
473	case "test":
474		return UseTestDefaults()
475	default:
476		panic(fmt.Sprintf("unsupported defaults value %q", FindDefaultsParam()))
477	}
478}
479