1package cli
2
3import (
4	"context"
5	"errors"
6	"flag"
7	"fmt"
8	"strings"
9)
10
// Context is a type that is passed through to each Handler action in a cli
// application. It embeds a standard context.Context and can be used to
// retrieve context-specific args and parsed command-line options, including
// those of its ancestor contexts (see Lineage).
type Context struct {
	// Embedded standard context; NewContext copies it from the parent
	// context, or falls back to context.Background().
	context.Context
	// App is the application this context was created for.
	App           *App
	// Command is set to an empty &Command{} by NewContext; presumably the
	// caller replaces it for subcommand invocations — verify at call sites.
	// lookupFlag reads Command.Flags from every context in the lineage.
	Command       *Command
	// shellComplete is inherited from the parent context by NewContext.
	// NOTE(review): presumably true while generating shell completions — confirm.
	shellComplete bool
	// flagSet holds this context's flag definitions and parsed values. It is
	// stored as passed to NewContext and may therefore be nil.
	flagSet       *flag.FlagSet
	// parentContext is the enclosing context (nil at the root); Lineage
	// walks this chain from child to parent.
	parentContext *Context
}
23
24// NewContext creates a new context. For use in when invoking an App or Command action.
25func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context {
26	c := &Context{App: app, flagSet: set, parentContext: parentCtx}
27	if parentCtx != nil {
28		c.Context = parentCtx.Context
29		c.shellComplete = parentCtx.shellComplete
30		if parentCtx.flagSet == nil {
31			parentCtx.flagSet = &flag.FlagSet{}
32		}
33	}
34
35	c.Command = &Command{}
36
37	if c.Context == nil {
38		c.Context = context.Background()
39	}
40
41	return c
42}
43
44// NumFlags returns the number of flags set
45func (c *Context) NumFlags() int {
46	return c.flagSet.NFlag()
47}
48
49// Set sets a context flag to a value.
50func (c *Context) Set(name, value string) error {
51	return c.flagSet.Set(name, value)
52}
53
54// IsSet determines if the flag was actually set
55func (c *Context) IsSet(name string) bool {
56	if fs := lookupFlagSet(name, c); fs != nil {
57		if fs := lookupFlagSet(name, c); fs != nil {
58			isSet := false
59			fs.Visit(func(f *flag.Flag) {
60				if f.Name == name {
61					isSet = true
62				}
63			})
64			if isSet {
65				return true
66			}
67		}
68
69		f := lookupFlag(name, c)
70		if f == nil {
71			return false
72		}
73
74		return f.IsSet()
75	}
76
77	return false
78}
79
80// LocalFlagNames returns a slice of flag names used in this context.
81func (c *Context) LocalFlagNames() []string {
82	var names []string
83	c.flagSet.Visit(makeFlagNameVisitor(&names))
84	return names
85}
86
87// FlagNames returns a slice of flag names used by the this context and all of
88// its parent contexts.
89func (c *Context) FlagNames() []string {
90	var names []string
91	for _, ctx := range c.Lineage() {
92		ctx.flagSet.Visit(makeFlagNameVisitor(&names))
93	}
94	return names
95}
96
97// Lineage returns *this* context and all of its ancestor contexts in order from
98// child to parent
99func (c *Context) Lineage() []*Context {
100	var lineage []*Context
101
102	for cur := c; cur != nil; cur = cur.parentContext {
103		lineage = append(lineage, cur)
104	}
105
106	return lineage
107}
108
109// Value returns the value of the flag corresponding to `name`
110func (c *Context) Value(name string) interface{} {
111	return c.flagSet.Lookup(name).Value.(flag.Getter).Get()
112}
113
114// Args returns the command line arguments associated with the context.
115func (c *Context) Args() Args {
116	ret := args(c.flagSet.Args())
117	return &ret
118}
119
120// NArg returns the number of the command line arguments.
121func (c *Context) NArg() int {
122	return c.Args().Len()
123}
124
125func lookupFlag(name string, ctx *Context) Flag {
126	for _, c := range ctx.Lineage() {
127		if c.Command == nil {
128			continue
129		}
130
131		for _, f := range c.Command.Flags {
132			for _, n := range f.Names() {
133				if n == name {
134					return f
135				}
136			}
137		}
138	}
139
140	if ctx.App != nil {
141		for _, f := range ctx.App.Flags {
142			for _, n := range f.Names() {
143				if n == name {
144					return f
145				}
146			}
147		}
148	}
149
150	return nil
151}
152
153func lookupFlagSet(name string, ctx *Context) *flag.FlagSet {
154	for _, c := range ctx.Lineage() {
155		if f := c.flagSet.Lookup(name); f != nil {
156			return c.flagSet
157		}
158	}
159
160	return nil
161}
162
163func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) {
164	switch ff.Value.(type) {
165	case Serializer:
166		_ = set.Set(name, ff.Value.(Serializer).Serialize())
167	default:
168		_ = set.Set(name, ff.Value.String())
169	}
170}
171
172func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
173	visited := make(map[string]bool)
174	set.Visit(func(f *flag.Flag) {
175		visited[f.Name] = true
176	})
177	for _, f := range flags {
178		parts := f.Names()
179		if len(parts) == 1 {
180			continue
181		}
182		var ff *flag.Flag
183		for _, name := range parts {
184			name = strings.Trim(name, " ")
185			if visited[name] {
186				if ff != nil {
187					return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name)
188				}
189				ff = set.Lookup(name)
190			}
191		}
192		if ff == nil {
193			continue
194		}
195		for _, name := range parts {
196			name = strings.Trim(name, " ")
197			if !visited[name] {
198				copyFlag(name, ff, set)
199			}
200		}
201	}
202	return nil
203}
204
// makeFlagNameVisitor returns a visitor for flag.FlagSet.Visit that appends
// one display name per visited flag to *names.
//
// A flag's Name is split on commas. The longest space-trimmed part is chosen,
// with the earliest part winning ties. Flags whose chosen name is empty are
// skipped.
func makeFlagNameVisitor(names *[]string) func(*flag.Flag) {
	return func(f *flag.Flag) {
		longest := ""
		for _, alias := range strings.Split(f.Name, ",") {
			if alias = strings.TrimSpace(alias); len(alias) > len(longest) {
				longest = alias
			}
		}
		if longest == "" {
			return
		}
		*names = append(*names, longest)
	}
}
222
// requiredFlagsErr is the error produced by checkRequiredFlags. In addition
// to the message, it exposes the names of the flags that were missing.
type requiredFlagsErr interface {
	error
	getMissingFlags() []string
}

// errRequiredFlags records the names of required flags that were not set.
type errRequiredFlags struct {
	missingFlags []string
}

// Compile-time check that *errRequiredFlags satisfies requiredFlagsErr.
var _ requiredFlagsErr = (*errRequiredFlags)(nil)

// Error renders a singular message for exactly one missing flag. Otherwise it
// renders a plural message quoting the comma-joined list of names.
func (e *errRequiredFlags) Error() string {
	switch len(e.missingFlags) {
	case 1:
		return fmt.Sprintf("Required flag %q not set", e.missingFlags[0])
	default:
		return fmt.Sprintf("Required flags %q not set", strings.Join(e.missingFlags, ", "))
	}
}

// getMissingFlags returns the names of the required flags that were not set.
func (e *errRequiredFlags) getMissingFlags() []string { return e.missingFlags }
244
245func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
246	var missingFlags []string
247	for _, f := range flags {
248		if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
249			var flagPresent bool
250			var flagName string
251
252			for _, key := range f.Names() {
253				if len(key) > 1 {
254					flagName = key
255				}
256
257				if context.IsSet(strings.TrimSpace(key)) {
258					flagPresent = true
259				}
260			}
261
262			if !flagPresent && flagName != "" {
263				missingFlags = append(missingFlags, flagName)
264			}
265		}
266	}
267
268	if len(missingFlags) != 0 {
269		return &errRequiredFlags{missingFlags: missingFlags}
270	}
271
272	return nil
273}
274