1package cli 2 3import ( 4 "context" 5 "errors" 6 "flag" 7 "fmt" 8 "strings" 9) 10 11// Context is a type that is passed through to 12// each Handler action in a cli application. Context 13// can be used to retrieve context-specific args and 14// parsed command-line options. 15type Context struct { 16 context.Context 17 App *App 18 Command *Command 19 shellComplete bool 20 flagSet *flag.FlagSet 21 parentContext *Context 22} 23 24// NewContext creates a new context. For use in when invoking an App or Command action. 25func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context { 26 c := &Context{App: app, flagSet: set, parentContext: parentCtx} 27 if parentCtx != nil { 28 c.Context = parentCtx.Context 29 c.shellComplete = parentCtx.shellComplete 30 if parentCtx.flagSet == nil { 31 parentCtx.flagSet = &flag.FlagSet{} 32 } 33 } 34 35 c.Command = &Command{} 36 37 if c.Context == nil { 38 c.Context = context.Background() 39 } 40 41 return c 42} 43 44// NumFlags returns the number of flags set 45func (c *Context) NumFlags() int { 46 return c.flagSet.NFlag() 47} 48 49// Set sets a context flag to a value. 50func (c *Context) Set(name, value string) error { 51 return c.flagSet.Set(name, value) 52} 53 54// IsSet determines if the flag was actually set 55func (c *Context) IsSet(name string) bool { 56 if fs := lookupFlagSet(name, c); fs != nil { 57 if fs := lookupFlagSet(name, c); fs != nil { 58 isSet := false 59 fs.Visit(func(f *flag.Flag) { 60 if f.Name == name { 61 isSet = true 62 } 63 }) 64 if isSet { 65 return true 66 } 67 } 68 69 f := lookupFlag(name, c) 70 if f == nil { 71 return false 72 } 73 74 return f.IsSet() 75 } 76 77 return false 78} 79 80// LocalFlagNames returns a slice of flag names used in this context. 81func (c *Context) LocalFlagNames() []string { 82 var names []string 83 c.flagSet.Visit(makeFlagNameVisitor(&names)) 84 return names 85} 86 87// FlagNames returns a slice of flag names used by the this context and all of 88// its parent contexts. 89func (c *Context) FlagNames() []string { 90 var names []string 91 for _, ctx := range c.Lineage() { 92 ctx.flagSet.Visit(makeFlagNameVisitor(&names)) 93 } 94 return names 95} 96 97// Lineage returns *this* context and all of its ancestor contexts in order from 98// child to parent 99func (c *Context) Lineage() []*Context { 100 var lineage []*Context 101 102 for cur := c; cur != nil; cur = cur.parentContext { 103 lineage = append(lineage, cur) 104 } 105 106 return lineage 107} 108 109// Value returns the value of the flag corresponding to `name` 110func (c *Context) Value(name string) interface{} { 111 return c.flagSet.Lookup(name).Value.(flag.Getter).Get() 112} 113 114// Args returns the command line arguments associated with the context. 115func (c *Context) Args() Args { 116 ret := args(c.flagSet.Args()) 117 return &ret 118} 119 120// NArg returns the number of the command line arguments. 121func (c *Context) NArg() int { 122 return c.Args().Len() 123} 124 125func lookupFlag(name string, ctx *Context) Flag { 126 for _, c := range ctx.Lineage() { 127 if c.Command == nil { 128 continue 129 } 130 131 for _, f := range c.Command.Flags { 132 for _, n := range f.Names() { 133 if n == name { 134 return f 135 } 136 } 137 } 138 } 139 140 if ctx.App != nil { 141 for _, f := range ctx.App.Flags { 142 for _, n := range f.Names() { 143 if n == name { 144 return f 145 } 146 } 147 } 148 } 149 150 return nil 151} 152 153func lookupFlagSet(name string, ctx *Context) *flag.FlagSet { 154 for _, c := range ctx.Lineage() { 155 if f := c.flagSet.Lookup(name); f != nil { 156 return c.flagSet 157 } 158 } 159 160 return nil 161} 162 163func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) { 164 switch ff.Value.(type) { 165 case Serializer: 166 _ = set.Set(name, ff.Value.(Serializer).Serialize()) 167 default: 168 _ = set.Set(name, ff.Value.String()) 169 } 170} 171 172func normalizeFlags(flags []Flag, set *flag.FlagSet) error { 173 visited := make(map[string]bool) 174 set.Visit(func(f *flag.Flag) { 175 visited[f.Name] = true 176 }) 177 for _, f := range flags { 178 parts := f.Names() 179 if len(parts) == 1 { 180 continue 181 } 182 var ff *flag.Flag 183 for _, name := range parts { 184 name = strings.Trim(name, " ") 185 if visited[name] { 186 if ff != nil { 187 return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name) 188 } 189 ff = set.Lookup(name) 190 } 191 } 192 if ff == nil { 193 continue 194 } 195 for _, name := range parts { 196 name = strings.Trim(name, " ") 197 if !visited[name] { 198 copyFlag(name, ff, set) 199 } 200 } 201 } 202 return nil 203} 204 205func makeFlagNameVisitor(names *[]string) func(*flag.Flag) { 206 return func(f *flag.Flag) { 207 nameParts := strings.Split(f.Name, ",") 208 name := strings.TrimSpace(nameParts[0]) 209 210 for _, part := range nameParts { 211 part = strings.TrimSpace(part) 212 if len(part) > len(name) { 213 name = part 214 } 215 } 216 217 if name != "" { 218 *names = append(*names, name) 219 } 220 } 221} 222 223type requiredFlagsErr interface { 224 error 225 getMissingFlags() []string 226} 227 228type errRequiredFlags struct { 229 missingFlags []string 230} 231 232func (e *errRequiredFlags) Error() string { 233 numberOfMissingFlags := len(e.missingFlags) 234 if numberOfMissingFlags == 1 { 235 return fmt.Sprintf("Required flag %q not set", e.missingFlags[0]) 236 } 237 joinedMissingFlags := strings.Join(e.missingFlags, ", ") 238 return fmt.Sprintf("Required flags %q not set", joinedMissingFlags) 239} 240 241func (e *errRequiredFlags) getMissingFlags() []string { 242 return e.missingFlags 243} 244 245func checkRequiredFlags(flags []Flag, context *Context) requiredFlagsErr { 246 var missingFlags []string 247 for _, f := range flags { 248 if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { 249 var flagPresent bool 250 var flagName string 251 252 for _, key := range f.Names() { 253 if len(key) > 1 { 254 flagName = key 255 } 256 257 if context.IsSet(strings.TrimSpace(key)) { 258 flagPresent = true 259 } 260 } 261 262 if !flagPresent && flagName != "" { 263 missingFlags = append(missingFlags, flagName) 264 } 265 } 266 } 267 268 if len(missingFlags) != 0 { 269 return &errRequiredFlags{missingFlags: missingFlags} 270 } 271 272 return nil 273} 274