1// Copyright 2012 Jesse van den Kieboom. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package flags
6
7import (
8	"errors"
9	"reflect"
10	"strings"
11	"unicode/utf8"
12)
13
// ErrNotPointerToStruct indicates that a provided data container is not
// a pointer to a struct. Only pointers to structs are valid data containers
// for options.
//
// Note that scanType does not return this value as an error: it panics
// with it when the group's data fails the pointer-to-struct check.
var ErrNotPointerToStruct = errors.New("provided data is not a pointer to struct")
18
// Group represents an option group. Option groups can be used to logically
// group options together under a description. Groups are only used to provide
// more structure to options both for the user (as displayed in the help message)
// and for you, since groups can be nested.
type Group struct {
	// A short description of the group. The
	// short description is primarily used in the built-in generated help
	// message. Find matches subgroups on this field, case insensitively.
	ShortDescription string

	// A long description of the group. The long
	// description is primarily used to present information on commands
	// (Command embeds Group) in the built-in generated help and man pages.
	LongDescription string

	// The namespace of the group
	Namespace string

	// If true, the group is not displayed in the help or man page
	Hidden bool

	// The parent of the group or nil if it has no parent. AddGroup stores
	// the *Group it was called on here.
	parent interface{}

	// All the options in the group
	options []*Option

	// All the subgroups
	groups []*Group

	// Whether the group represents the built-in help group
	isBuiltinHelp bool

	// The value handed to newGroup when the group was created. scanType
	// requires it to be a pointer to a struct and scans that struct's
	// fields for options and subgroups.
	data interface{}
}
54
// scanHandler is invoked by scanStruct for a struct value that was reached
// through a struct field. It reports true when it has fully handled the value
// (scanStruct then does not scan that struct's fields itself), and a non-nil
// error to abort the scan.
type scanHandler func(reflect.Value, *reflect.StructField) (bool, error)
56
57// AddGroup adds a new group to the command with the given name and data. The
58// data needs to be a pointer to a struct from which the fields indicate which
59// options are in the group.
60func (g *Group) AddGroup(shortDescription string, longDescription string, data interface{}) (*Group, error) {
61	group := newGroup(shortDescription, longDescription, data)
62
63	group.parent = g
64
65	if err := group.scan(); err != nil {
66		return nil, err
67	}
68
69	g.groups = append(g.groups, group)
70	return group, nil
71}
72
// Groups returns the list of groups embedded in this group. The returned
// slice is the group's own internal slice, not a copy.
func (g *Group) Groups() []*Group {
	return g.groups
}
77
// Options returns the list of options in this group. Options of subgroups
// are not included, and the returned slice is the group's own internal
// slice, not a copy.
func (g *Group) Options() []*Option {
	return g.options
}
82
83// Find locates the subgroup with the given short description and returns it.
84// If no such group can be found Find will return nil. Note that the description
85// is matched case insensitively.
86func (g *Group) Find(shortDescription string) *Group {
87	lshortDescription := strings.ToLower(shortDescription)
88
89	var ret *Group
90
91	g.eachGroup(func(gg *Group) {
92		if gg != g && strings.ToLower(gg.ShortDescription) == lshortDescription {
93			ret = gg
94		}
95	})
96
97	return ret
98}
99
100func (g *Group) findOption(matcher func(*Option) bool) (option *Option) {
101	g.eachGroup(func(g *Group) {
102		for _, opt := range g.options {
103			if option == nil && matcher(opt) {
104				option = opt
105			}
106		}
107	})
108
109	return option
110}
111
112// FindOptionByLongName finds an option that is part of the group, or any of its
113// subgroups, by matching its long name (including the option namespace).
114func (g *Group) FindOptionByLongName(longName string) *Option {
115	return g.findOption(func(option *Option) bool {
116		return option.LongNameWithNamespace() == longName
117	})
118}
119
120// FindOptionByShortName finds an option that is part of the group, or any of
121// its subgroups, by matching its short name.
122func (g *Group) FindOptionByShortName(shortName rune) *Option {
123	return g.findOption(func(option *Option) bool {
124		return option.ShortName == shortName
125	})
126}
127
128func newGroup(shortDescription string, longDescription string, data interface{}) *Group {
129	return &Group{
130		ShortDescription: shortDescription,
131		LongDescription:  longDescription,
132
133		data: data,
134	}
135}
136
137func (g *Group) optionByName(name string, namematch func(*Option, string) bool) *Option {
138	prio := 0
139	var retopt *Option
140
141	g.eachGroup(func(g *Group) {
142		for _, opt := range g.options {
143			if namematch != nil && namematch(opt, name) && prio < 4 {
144				retopt = opt
145				prio = 4
146			}
147
148			if name == opt.field.Name && prio < 3 {
149				retopt = opt
150				prio = 3
151			}
152
153			if name == opt.LongNameWithNamespace() && prio < 2 {
154				retopt = opt
155				prio = 2
156			}
157
158			if opt.ShortName != 0 && name == string(opt.ShortName) && prio < 1 {
159				retopt = opt
160				prio = 1
161			}
162		}
163	})
164
165	return retopt
166}
167
168func (g *Group) eachGroup(f func(*Group)) {
169	f(g)
170
171	for _, gg := range g.groups {
172		gg.eachGroup(f)
173	}
174}
175
// isStringFalsy reports whether s is one of the values treated as false:
// the empty string, "false", "no" or "0". The comparison is exact, so it is
// case sensitive and does not trim whitespace.
func isStringFalsy(s string) bool {
	switch s {
	case "", "false", "no", "0":
		return true
	default:
		return false
	}
}
179
// scanStruct walks the fields of the struct value realval and appends an
// Option to g.options for every scanned field that is tagged with a long
// name, a short name or an ini-name.
//
// When sfield is non-nil (realval was reached through a struct field),
// handler is offered the value first; if it reports true the struct counts
// as fully handled and its fields are not scanned here. Fields of struct
// type, or of pointer-to-struct type, are scanned recursively with the same
// handler.
//
// It returns the first error produced by tag parsing, by handler, by a
// nested scan, or by tag validation (a short name longer than one character,
// or default values on a boolean option).
func (g *Group) scanStruct(realval reflect.Value, sfield *reflect.StructField, handler scanHandler) error {
	stype := realval.Type()

	// Give the handler the chance to claim the whole struct before its
	// fields are looked at individually.
	if sfield != nil {
		if ok, err := handler(realval, sfield); err != nil {
			return err
		} else if ok {
			return nil
		}
	}

	for i := 0; i < stype.NumField(); i++ {
		field := stype.Field(i)

		// PkgPath is set only for non-exported fields, which we ignore.
		// Embedded (anonymous) fields are still scanned even when their
		// type name is not exported.
		if field.PkgPath != "" && !field.Anonymous {
			continue
		}

		mtag := newMultiTag(string(field.Tag))

		if err := mtag.Parse(); err != nil {
			return err
		}

		// Skip fields with a non-empty no-flag tag
		if mtag.Get("no-flag") != "" {
			continue
		}

		// Dive deep into structs or pointers to structs
		kind := field.Type.Kind()
		fld := realval.Field(i)

		if kind == reflect.Struct {
			if err := g.scanStruct(fld, &field, handler); err != nil {
				return err
			}
		} else if kind == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
			// Remember how much was registered so far, to detect whether the
			// nested scan contributed anything.
			flagCountBefore := len(g.options) + len(g.groups)

			// A nil pointer is scanned through a freshly allocated struct
			// that is not yet stored in the field.
			if fld.IsNil() {
				fld = reflect.New(fld.Type().Elem())
			}

			if err := g.scanStruct(reflect.Indirect(fld), &field, handler); err != nil {
				return err
			}

			// Only store the pointer back into the field when the nested
			// scan added options or groups to g; otherwise a field that was
			// nil stays nil.
			if len(g.options)+len(g.groups) != flagCountBefore {
				realval.Field(i).Set(fld)
			}
		}

		// There is no continue above: a struct or pointer field that carries
		// name tags is additionally registered as an option below.
		longname := mtag.Get("long")
		shortname := mtag.Get("short")

		// Need at least either a short or long name, or an ini-name
		if longname == "" && shortname == "" && mtag.Get("ini-name") == "" {
			continue
		}

		// The short name is counted and decoded as runes, not bytes, so a
		// single multi-byte character is accepted.
		short := rune(0)
		rc := utf8.RuneCountInString(shortname)

		if rc > 1 {
			return newErrorf(ErrShortNameTooLong,
				"short names can only be 1 character long, not `%s'",
				shortname)

		} else if rc == 1 {
			short, _ = utf8.DecodeRuneInString(shortname)
		}

		description := mtag.Get("description")
		def := mtag.GetMany("default")

		optionalValue := mtag.GetMany("optional-value")
		valueName := mtag.Get("value-name")
		defaultMask := mtag.Get("default-mask")

		// These tags are booleans: every value except "", "false", "no" and
		// "0" switches them on (see isStringFalsy).
		optional := !isStringFalsy(mtag.Get("optional"))
		required := !isStringFalsy(mtag.Get("required"))
		choices := mtag.GetMany("choice")
		hidden := !isStringFalsy(mtag.Get("hidden"))

		option := &Option{
			Description:      description,
			ShortName:        short,
			LongName:         longname,
			Default:          def,
			EnvDefaultKey:    mtag.Get("env"),
			EnvDefaultDelim:  mtag.Get("env-delim"),
			OptionalArgument: optional,
			OptionalValue:    optionalValue,
			Required:         required,
			ValueName:        valueName,
			DefaultMask:      defaultMask,
			Choices:          choices,
			Hidden:           hidden,

			group: g,

			field: field,
			value: realval.Field(i),
			tag:   mtag,
		}

		// Reject default values on boolean options before registering them.
		if option.isBool() && option.Default != nil {
			return newErrorf(ErrInvalidTag,
				"boolean flag `%s' may not have default values, they always default to `false' and can only be turned on",
				option.shortAndLongName())
		}

		g.options = append(g.options, option)
	}

	return nil
}
299
300func (g *Group) checkForDuplicateFlags() *Error {
301	shortNames := make(map[rune]*Option)
302	longNames := make(map[string]*Option)
303
304	var duplicateError *Error
305
306	g.eachGroup(func(g *Group) {
307		for _, option := range g.options {
308			if option.LongName != "" {
309				longName := option.LongNameWithNamespace()
310
311				if otherOption, ok := longNames[longName]; ok {
312					duplicateError = newErrorf(ErrDuplicatedFlag, "option `%s' uses the same long name as option `%s'", option, otherOption)
313					return
314				}
315				longNames[longName] = option
316			}
317			if option.ShortName != 0 {
318				if otherOption, ok := shortNames[option.ShortName]; ok {
319					duplicateError = newErrorf(ErrDuplicatedFlag, "option `%s' uses the same short name as option `%s'", option, otherOption)
320					return
321				}
322				shortNames[option.ShortName] = option
323			}
324		}
325	})
326
327	return duplicateError
328}
329
330func (g *Group) scanSubGroupHandler(realval reflect.Value, sfield *reflect.StructField) (bool, error) {
331	mtag := newMultiTag(string(sfield.Tag))
332
333	if err := mtag.Parse(); err != nil {
334		return true, err
335	}
336
337	subgroup := mtag.Get("group")
338
339	if len(subgroup) != 0 {
340		var ptrval reflect.Value
341
342		if realval.Kind() == reflect.Ptr {
343			ptrval = realval
344
345			if ptrval.IsNil() {
346				ptrval.Set(reflect.New(ptrval.Type()))
347			}
348		} else {
349			ptrval = realval.Addr()
350		}
351
352		description := mtag.Get("description")
353
354		group, err := g.AddGroup(subgroup, description, ptrval.Interface())
355
356		if err != nil {
357			return true, err
358		}
359
360		group.Namespace = mtag.Get("namespace")
361		group.Hidden = mtag.Get("hidden") != ""
362
363		return true, nil
364	}
365
366	return false, nil
367}
368
369func (g *Group) scanType(handler scanHandler) error {
370	// Get all the public fields in the data struct
371	ptrval := reflect.ValueOf(g.data)
372
373	if ptrval.Type().Kind() != reflect.Ptr {
374		panic(ErrNotPointerToStruct)
375	}
376
377	stype := ptrval.Type().Elem()
378
379	if stype.Kind() != reflect.Struct {
380		panic(ErrNotPointerToStruct)
381	}
382
383	realval := reflect.Indirect(ptrval)
384
385	if err := g.scanStruct(realval, nil, handler); err != nil {
386		return err
387	}
388
389	if err := g.checkForDuplicateFlags(); err != nil {
390		return err
391	}
392
393	return nil
394}
395
396func (g *Group) scan() error {
397	return g.scanType(g.scanSubGroupHandler)
398}
399
400func (g *Group) groupByName(name string) *Group {
401	if len(name) == 0 {
402		return g
403	}
404
405	return g.Find(name)
406}
407