// Copyright 2012 Jesse van den Kieboom. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package flags

import (
	"errors"
	"reflect"
	"strings"
	"unicode/utf8"
	"unsafe"
)

// ErrNotPointerToStruct indicates that a provided data container is not
// a pointer to a struct. Only pointers to structs are valid data containers
// for options.
var ErrNotPointerToStruct = errors.New("provided data is not a pointer to struct")

20// Group represents an option group. Option groups can be used to logically
21// group options together under a description. Groups are only used to provide
22// more structure to options both for the user (as displayed in the help message)
23// and for you, since groups can be nested.
24type Group struct {
25	// A short description of the group. The
26	// short description is primarily used in the built-in generated help
27	// message
28	ShortDescription string
29
30	// A long description of the group. The long
31	// description is primarily used to present information on commands
32	// (Command embeds Group) in the built-in generated help and man pages.
33	LongDescription string
34
35	// The namespace of the group
36	Namespace string
37
38	// If true, the group is not displayed in the help or man page
39	Hidden bool
40
41	// The parent of the group or nil if it has no parent
42	parent interface{}
43
44	// All the options in the group
45	options []*Option
46
47	// All the subgroups
48	groups []*Group
49
50	// Whether the group represents the built-in help group
51	isBuiltinHelp bool
52
53	data interface{}
54}
55
// scanHandler is the callback scanStruct invokes for a nested struct field
// before looking at that struct's own fields. It receives the struct value
// and the field that holds it. Returning true means the field has been fully
// handled and scanStruct must not scan its members; a non-nil error aborts
// the scan.
type scanHandler func(reflect.Value, *reflect.StructField) (bool, error)

58// AddGroup adds a new group to the command with the given name and data. The
59// data needs to be a pointer to a struct from which the fields indicate which
60// options are in the group.
61func (g *Group) AddGroup(shortDescription string, longDescription string, data interface{}) (*Group, error) {
62	group := newGroup(shortDescription, longDescription, data)
63
64	group.parent = g
65
66	if err := group.scan(); err != nil {
67		return nil, err
68	}
69
70	g.groups = append(g.groups, group)
71	return group, nil
72}
73
74// Groups returns the list of groups embedded in this group.
75func (g *Group) Groups() []*Group {
76	return g.groups
77}
78
79// Options returns the list of options in this group.
80func (g *Group) Options() []*Option {
81	return g.options
82}
83
84// Find locates the subgroup with the given short description and returns it.
85// If no such group can be found Find will return nil. Note that the description
86// is matched case insensitively.
87func (g *Group) Find(shortDescription string) *Group {
88	lshortDescription := strings.ToLower(shortDescription)
89
90	var ret *Group
91
92	g.eachGroup(func(gg *Group) {
93		if gg != g && strings.ToLower(gg.ShortDescription) == lshortDescription {
94			ret = gg
95		}
96	})
97
98	return ret
99}
100
101func (g *Group) findOption(matcher func(*Option) bool) (option *Option) {
102	g.eachGroup(func(g *Group) {
103		for _, opt := range g.options {
104			if option == nil && matcher(opt) {
105				option = opt
106			}
107		}
108	})
109
110	return option
111}
112
113// Find an option that is part of the group, or any of its subgroups,
114// by matching its long name (including the option namespace).
115func (g *Group) FindOptionByLongName(longName string) *Option {
116	return g.findOption(func(option *Option) bool {
117		return option.LongNameWithNamespace() == longName
118	})
119}
120
121// Find an option that is part of the group, or any of its subgroups,
122// by matching its short name.
123func (g *Group) FindOptionByShortName(shortName rune) *Option {
124	return g.findOption(func(option *Option) bool {
125		return option.ShortName == shortName
126	})
127}
128
129func newGroup(shortDescription string, longDescription string, data interface{}) *Group {
130	return &Group{
131		ShortDescription: shortDescription,
132		LongDescription:  longDescription,
133
134		data: data,
135	}
136}
137
138func (g *Group) optionByName(name string, namematch func(*Option, string) bool) *Option {
139	prio := 0
140	var retopt *Option
141
142	g.eachGroup(func(g *Group) {
143		for _, opt := range g.options {
144			if namematch != nil && namematch(opt, name) && prio < 4 {
145				retopt = opt
146				prio = 4
147			}
148
149			if name == opt.field.Name && prio < 3 {
150				retopt = opt
151				prio = 3
152			}
153
154			if name == opt.LongNameWithNamespace() && prio < 2 {
155				retopt = opt
156				prio = 2
157			}
158
159			if opt.ShortName != 0 && name == string(opt.ShortName) && prio < 1 {
160				retopt = opt
161				prio = 1
162			}
163		}
164	})
165
166	return retopt
167}
168
169func (g *Group) eachGroup(f func(*Group)) {
170	f(g)
171
172	for _, gg := range g.groups {
173		gg.eachGroup(f)
174	}
175}
176
177func (g *Group) scanStruct(realval reflect.Value, sfield *reflect.StructField, handler scanHandler) error {
178	stype := realval.Type()
179
180	if sfield != nil {
181		if ok, err := handler(realval, sfield); err != nil {
182			return err
183		} else if ok {
184			return nil
185		}
186	}
187
188	for i := 0; i < stype.NumField(); i++ {
189		field := stype.Field(i)
190
191		// PkgName is set only for non-exported fields, which we ignore
192		if field.PkgPath != "" && !field.Anonymous {
193			continue
194		}
195
196		mtag := newMultiTag(string(field.Tag))
197
198		if err := mtag.Parse(); err != nil {
199			return err
200		}
201
202		// Skip fields with the no-flag tag
203		if mtag.Get("no-flag") != "" {
204			continue
205		}
206
207		// Dive deep into structs or pointers to structs
208		kind := field.Type.Kind()
209		fld := realval.Field(i)
210
211		if kind == reflect.Struct {
212			if err := g.scanStruct(fld, &field, handler); err != nil {
213				return err
214			}
215		} else if kind == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
216			if fld.IsNil() {
217				fld.Set(reflect.New(fld.Type().Elem()))
218			}
219
220			if err := g.scanStruct(reflect.Indirect(fld), &field, handler); err != nil {
221				return err
222			}
223		}
224
225		longname := mtag.Get("long")
226		shortname := mtag.Get("short")
227
228		// Need at least either a short or long name
229		if longname == "" && shortname == "" && mtag.Get("ini-name") == "" {
230			continue
231		}
232
233		short := rune(0)
234		rc := utf8.RuneCountInString(shortname)
235
236		if rc > 1 {
237			return newErrorf(ErrShortNameTooLong,
238				"short names can only be 1 character long, not `%s'",
239				shortname)
240
241		} else if rc == 1 {
242			short, _ = utf8.DecodeRuneInString(shortname)
243		}
244
245		description := mtag.Get("description")
246		def := mtag.GetMany("default")
247
248		optionalValue := mtag.GetMany("optional-value")
249		valueName := mtag.Get("value-name")
250		defaultMask := mtag.Get("default-mask")
251
252		optional := (mtag.Get("optional") != "")
253		required := (mtag.Get("required") != "")
254		choices := mtag.GetMany("choice")
255		hidden := (mtag.Get("hidden") != "")
256
257		option := &Option{
258			Description:      description,
259			ShortName:        short,
260			LongName:         longname,
261			Default:          def,
262			EnvDefaultKey:    mtag.Get("env"),
263			EnvDefaultDelim:  mtag.Get("env-delim"),
264			OptionalArgument: optional,
265			OptionalValue:    optionalValue,
266			Required:         required,
267			ValueName:        valueName,
268			DefaultMask:      defaultMask,
269			Choices:          choices,
270			Hidden:           hidden,
271
272			group: g,
273
274			field: field,
275			value: realval.Field(i),
276			tag:   mtag,
277		}
278
279		g.options = append(g.options, option)
280	}
281
282	return nil
283}
284
285func (g *Group) checkForDuplicateFlags() *Error {
286	shortNames := make(map[rune]*Option)
287	longNames := make(map[string]*Option)
288
289	var duplicateError *Error
290
291	g.eachGroup(func(g *Group) {
292		for _, option := range g.options {
293			if option.LongName != "" {
294				longName := option.LongNameWithNamespace()
295
296				if otherOption, ok := longNames[longName]; ok {
297					duplicateError = newErrorf(ErrDuplicatedFlag, "option `%s' uses the same long name as option `%s'", option, otherOption)
298					return
299				}
300				longNames[longName] = option
301			}
302			if option.ShortName != 0 {
303				if otherOption, ok := shortNames[option.ShortName]; ok {
304					duplicateError = newErrorf(ErrDuplicatedFlag, "option `%s' uses the same short name as option `%s'", option, otherOption)
305					return
306				}
307				shortNames[option.ShortName] = option
308			}
309		}
310	})
311
312	return duplicateError
313}
314
315func (g *Group) scanSubGroupHandler(realval reflect.Value, sfield *reflect.StructField) (bool, error) {
316	mtag := newMultiTag(string(sfield.Tag))
317
318	if err := mtag.Parse(); err != nil {
319		return true, err
320	}
321
322	subgroup := mtag.Get("group")
323
324	if len(subgroup) != 0 {
325		ptrval := reflect.NewAt(realval.Type(), unsafe.Pointer(realval.UnsafeAddr()))
326		description := mtag.Get("description")
327
328		group, err := g.AddGroup(subgroup, description, ptrval.Interface())
329		if err != nil {
330			return true, err
331		}
332
333		group.Namespace = mtag.Get("namespace")
334		group.Hidden = mtag.Get("hidden") != ""
335
336		return true, nil
337	}
338
339	return false, nil
340}
341
342func (g *Group) scanType(handler scanHandler) error {
343	// Get all the public fields in the data struct
344	ptrval := reflect.ValueOf(g.data)
345
346	if ptrval.Type().Kind() != reflect.Ptr {
347		panic(ErrNotPointerToStruct)
348	}
349
350	stype := ptrval.Type().Elem()
351
352	if stype.Kind() != reflect.Struct {
353		panic(ErrNotPointerToStruct)
354	}
355
356	realval := reflect.Indirect(ptrval)
357
358	if err := g.scanStruct(realval, nil, handler); err != nil {
359		return err
360	}
361
362	if err := g.checkForDuplicateFlags(); err != nil {
363		return err
364	}
365
366	return nil
367}
368
369func (g *Group) scan() error {
370	return g.scanType(g.scanSubGroupHandler)
371}
372
373func (g *Group) groupByName(name string) *Group {
374	if len(name) == 0 {
375		return g
376	}
377
378	return g.Find(name)
379}
380