1package cli
2
3import (
4	"errors"
5	"flag"
6	"fmt"
7	"os"
8	"reflect"
9	"strings"
10	"syscall"
11)
12
// Context is a type that is passed through to
// each Handler action in a cli application. Context
// can be used to retrieve context-specific Args and
// parsed command-line options.
type Context struct {
	// App is the application this context was created for. IsSet tolerates
	// a nil App (it checks before reading App.Flags).
	App           *App
	// Command is the command being run. IsSet treats a Command with an
	// empty Name as "no command" and falls back to App.Flags.
	Command       Command
	// shellComplete is copied from the parent context by NewContext.
	// NOTE(review): presumably marks a shell-completion run — confirm.
	shellComplete bool
	// flagSet holds the flags parsed for this context together with the
	// positional arguments left over after parsing.
	flagSet       *flag.FlagSet
	// setFlags caches IsSet results keyed by flag name; nil means "not yet
	// computed". Set and GlobalSet reset it to nil to force a rebuild.
	setFlags      map[string]bool
	// parentContext is the enclosing context, or nil for the root context
	// (globalContext walks this chain to find the root).
	parentContext *Context
}
25
26// NewContext creates a new context. For use in when invoking an App or Command action.
27func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context {
28	c := &Context{App: app, flagSet: set, parentContext: parentCtx}
29
30	if parentCtx != nil {
31		c.shellComplete = parentCtx.shellComplete
32	}
33
34	return c
35}
36
37// NumFlags returns the number of flags set
38func (c *Context) NumFlags() int {
39	return c.flagSet.NFlag()
40}
41
42// Set sets a context flag to a value.
43func (c *Context) Set(name, value string) error {
44	c.setFlags = nil
45	return c.flagSet.Set(name, value)
46}
47
48// GlobalSet sets a context flag to a value on the global flagset
49func (c *Context) GlobalSet(name, value string) error {
50	globalContext(c).setFlags = nil
51	return globalContext(c).flagSet.Set(name, value)
52}
53
54// IsSet determines if the flag was actually set
55func (c *Context) IsSet(name string) bool {
56	if c.setFlags == nil {
57		c.setFlags = make(map[string]bool)
58
59		c.flagSet.Visit(func(f *flag.Flag) {
60			c.setFlags[f.Name] = true
61		})
62
63		c.flagSet.VisitAll(func(f *flag.Flag) {
64			if _, ok := c.setFlags[f.Name]; ok {
65				return
66			}
67			c.setFlags[f.Name] = false
68		})
69
70		// XXX hack to support IsSet for flags with EnvVar
71		//
72		// There isn't an easy way to do this with the current implementation since
73		// whether a flag was set via an environment variable is very difficult to
74		// determine here. Instead, we intend to introduce a backwards incompatible
75		// change in version 2 to add `IsSet` to the Flag interface to push the
76		// responsibility closer to where the information required to determine
77		// whether a flag is set by non-standard means such as environment
78		// variables is available.
79		//
80		// See https://github.com/urfave/cli/issues/294 for additional discussion
81		flags := c.Command.Flags
82		if c.Command.Name == "" { // cannot == Command{} since it contains slice types
83			if c.App != nil {
84				flags = c.App.Flags
85			}
86		}
87		for _, f := range flags {
88			eachName(f.GetName(), func(name string) {
89				if isSet, ok := c.setFlags[name]; isSet || !ok {
90					// Check if a flag is set
91					if isSet {
92						// If the flag is set, also set its other aliases
93						eachName(f.GetName(), func(name string) {
94							c.setFlags[name] = true
95						})
96					}
97
98					return
99				}
100
101				val := reflect.ValueOf(f)
102				if val.Kind() == reflect.Ptr {
103					val = val.Elem()
104				}
105
106				filePathValue := val.FieldByName("FilePath")
107				if filePathValue.IsValid() {
108					eachName(filePathValue.String(), func(filePath string) {
109						if _, err := os.Stat(filePath); err == nil {
110							c.setFlags[name] = true
111							return
112						}
113					})
114				}
115
116				envVarValue := val.FieldByName("EnvVar")
117				if envVarValue.IsValid() {
118					eachName(envVarValue.String(), func(envVar string) {
119						envVar = strings.TrimSpace(envVar)
120						if _, ok := syscall.Getenv(envVar); ok {
121							c.setFlags[name] = true
122							return
123						}
124					})
125				}
126			})
127		}
128	}
129
130	return c.setFlags[name]
131}
132
133// GlobalIsSet determines if the global flag was actually set
134func (c *Context) GlobalIsSet(name string) bool {
135	ctx := c
136	if ctx.parentContext != nil {
137		ctx = ctx.parentContext
138	}
139
140	for ; ctx != nil; ctx = ctx.parentContext {
141		if ctx.IsSet(name) {
142			return true
143		}
144	}
145	return false
146}
147
148// FlagNames returns a slice of flag names used in this context.
149func (c *Context) FlagNames() (names []string) {
150	for _, f := range c.Command.Flags {
151		name := strings.Split(f.GetName(), ",")[0]
152		if name == "help" {
153			continue
154		}
155		names = append(names, name)
156	}
157	return
158}
159
160// GlobalFlagNames returns a slice of global flag names used by the app.
161func (c *Context) GlobalFlagNames() (names []string) {
162	for _, f := range c.App.Flags {
163		name := strings.Split(f.GetName(), ",")[0]
164		if name == "help" || name == "version" {
165			continue
166		}
167		names = append(names, name)
168	}
169	return
170}
171
172// Parent returns the parent context, if any
173func (c *Context) Parent() *Context {
174	return c.parentContext
175}
176
177// value returns the value of the flag coressponding to `name`
178func (c *Context) value(name string) interface{} {
179	return c.flagSet.Lookup(name).Value.(flag.Getter).Get()
180}
181
// Args holds the positional (non-flag) command-line arguments that remain
// after flag parsing, as returned by Context.Args.
type Args []string
184
185// Args returns the command line arguments associated with the context.
186func (c *Context) Args() Args {
187	args := Args(c.flagSet.Args())
188	return args
189}
190
191// NArg returns the number of the command line arguments.
192func (c *Context) NArg() int {
193	return len(c.Args())
194}
195
196// Get returns the nth argument, or else a blank string
197func (a Args) Get(n int) string {
198	if len(a) > n {
199		return a[n]
200	}
201	return ""
202}
203
204// First returns the first argument, or else a blank string
205func (a Args) First() string {
206	return a.Get(0)
207}
208
209// Tail returns the rest of the arguments (not the first one)
210// or else an empty string slice
211func (a Args) Tail() []string {
212	if len(a) >= 2 {
213		return []string(a)[1:]
214	}
215	return []string{}
216}
217
218// Present checks if there are any arguments present
219func (a Args) Present() bool {
220	return len(a) != 0
221}
222
223// Swap swaps arguments at the given indexes
224func (a Args) Swap(from, to int) error {
225	if from >= len(a) || to >= len(a) {
226		return errors.New("index out of range")
227	}
228	a[from], a[to] = a[to], a[from]
229	return nil
230}
231
232func globalContext(ctx *Context) *Context {
233	if ctx == nil {
234		return nil
235	}
236
237	for {
238		if ctx.parentContext == nil {
239			return ctx
240		}
241		ctx = ctx.parentContext
242	}
243}
244
245func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet {
246	if ctx.parentContext != nil {
247		ctx = ctx.parentContext
248	}
249	for ; ctx != nil; ctx = ctx.parentContext {
250		if f := ctx.flagSet.Lookup(name); f != nil {
251			return ctx.flagSet
252		}
253	}
254	return nil
255}
256
257func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) {
258	switch ff.Value.(type) {
259	case *StringSlice:
260	default:
261		_ = set.Set(name, ff.Value.String())
262	}
263}
264
265func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
266	visited := make(map[string]bool)
267	set.Visit(func(f *flag.Flag) {
268		visited[f.Name] = true
269	})
270	for _, f := range flags {
271		parts := strings.Split(f.GetName(), ",")
272		if len(parts) == 1 {
273			continue
274		}
275		var ff *flag.Flag
276		for _, name := range parts {
277			name = strings.Trim(name, " ")
278			if visited[name] {
279				if ff != nil {
280					return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name)
281				}
282				ff = set.Lookup(name)
283			}
284		}
285		if ff == nil {
286			continue
287		}
288		for _, name := range parts {
289			name = strings.Trim(name, " ")
290			if !visited[name] {
291				copyFlag(name, ff, set)
292			}
293		}
294	}
295	return nil
296}
297
// requiredFlagsErr is an error that also exposes the names of the required
// flags that were not provided.
type requiredFlagsErr interface {
	error
	getMissingFlags() []string
}

// errRequiredFlags is the requiredFlagsErr returned by checkRequiredFlags; it
// records the names of the missing required flags.
type errRequiredFlags struct {
	missingFlags []string
}

// Error renders the missing flag names as a user-facing message. It uses the
// singular wording when exactly one flag is missing. Otherwise it quotes the
// comma-joined list of names.
func (e *errRequiredFlags) Error() string {
	switch missing := e.missingFlags; len(missing) {
	case 1:
		return fmt.Sprintf("Required flag %q not set", missing[0])
	default:
		return fmt.Sprintf("Required flags %q not set", strings.Join(missing, ", "))
	}
}

// getMissingFlags returns the names of the required flags that were not set.
func (e *errRequiredFlags) getMissingFlags() []string {
	return e.missingFlags
}
319
320func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr {
321	var missingFlags []string
322	for _, f := range flags {
323		if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
324			var flagPresent bool
325			var flagName string
326			for _, key := range strings.Split(f.GetName(), ",") {
327				key = strings.TrimSpace(key)
328				if len(key) > 1 {
329					flagName = key
330				}
331
332				if context.IsSet(key) {
333					flagPresent = true
334				}
335			}
336
337			if !flagPresent && flagName != "" {
338				missingFlags = append(missingFlags, flagName)
339			}
340		}
341	}
342
343	if len(missingFlags) != 0 {
344		return &errRequiredFlags{missingFlags: missingFlags}
345	}
346
347	return nil
348}
349