1// Copyright 2012 Jesse van den Kieboom. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package flags 6 7import ( 8 "fmt" 9 "reflect" 10 "strconv" 11 "strings" 12 "time" 13) 14 15// Marshaler is the interface implemented by types that can marshal themselves 16// to a string representation of the flag. 17type Marshaler interface { 18 // MarshalFlag marshals a flag value to its string representation. 19 MarshalFlag() (string, error) 20} 21 22// Unmarshaler is the interface implemented by types that can unmarshal a flag 23// argument to themselves. The provided value is directly passed from the 24// command line. 25type Unmarshaler interface { 26 // UnmarshalFlag unmarshals a string value representation to the flag 27 // value (which therefore needs to be a pointer receiver). 28 UnmarshalFlag(value string) error 29} 30 31func getBase(options multiTag, base int) (int, error) { 32 sbase := options.Get("base") 33 34 var err error 35 var ivbase int64 36 37 if sbase != "" { 38 ivbase, err = strconv.ParseInt(sbase, 10, 32) 39 base = int(ivbase) 40 } 41 42 return base, err 43} 44 45func convertMarshal(val reflect.Value) (bool, string, error) { 46 // Check first for the Marshaler interface 47 if val.Type().NumMethod() > 0 && val.CanInterface() { 48 if marshaler, ok := val.Interface().(Marshaler); ok { 49 ret, err := marshaler.MarshalFlag() 50 return true, ret, err 51 } 52 } 53 54 return false, "", nil 55} 56 57func convertToString(val reflect.Value, options multiTag) (string, error) { 58 if ok, ret, err := convertMarshal(val); ok { 59 return ret, err 60 } 61 62 tp := val.Type() 63 64 // Support for time.Duration 65 if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() { 66 stringer := val.Interface().(fmt.Stringer) 67 return stringer.String(), nil 68 } 69 70 switch tp.Kind() { 71 case reflect.String: 72 return val.String(), nil 73 case reflect.Bool: 74 if val.Bool() { 75 return "true", nil 76 } 77 78 return "false", nil 79 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 80 base, err := getBase(options, 10) 81 82 if err != nil { 83 return "", err 84 } 85 86 return strconv.FormatInt(val.Int(), base), nil 87 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 88 base, err := getBase(options, 10) 89 90 if err != nil { 91 return "", err 92 } 93 94 return strconv.FormatUint(val.Uint(), base), nil 95 case reflect.Float32, reflect.Float64: 96 return strconv.FormatFloat(val.Float(), 'g', -1, tp.Bits()), nil 97 case reflect.Slice: 98 if val.Len() == 0 { 99 return "", nil 100 } 101 102 ret := "[" 103 104 for i := 0; i < val.Len(); i++ { 105 if i != 0 { 106 ret += ", " 107 } 108 109 item, err := convertToString(val.Index(i), options) 110 111 if err != nil { 112 return "", err 113 } 114 115 ret += item 116 } 117 118 return ret + "]", nil 119 case reflect.Map: 120 ret := "{" 121 122 for i, key := range val.MapKeys() { 123 if i != 0 { 124 ret += ", " 125 } 126 127 keyitem, err := convertToString(key, options) 128 129 if err != nil { 130 return "", err 131 } 132 133 item, err := convertToString(val.MapIndex(key), options) 134 135 if err != nil { 136 return "", err 137 } 138 139 ret += keyitem + ":" + item 140 } 141 142 return ret + "}", nil 143 case reflect.Ptr: 144 return convertToString(reflect.Indirect(val), options) 145 case reflect.Interface: 146 if !val.IsNil() { 147 return convertToString(val.Elem(), options) 148 } 149 } 150 151 return "", nil 152} 153 154func convertUnmarshal(val string, retval reflect.Value) (bool, error) { 155 if retval.Type().NumMethod() > 0 && retval.CanInterface() { 156 if unmarshaler, ok := retval.Interface().(Unmarshaler); ok { 157 if retval.IsNil() { 158 retval.Set(reflect.New(retval.Type().Elem())) 159 160 // Re-assign from the new value 161 unmarshaler = retval.Interface().(Unmarshaler) 162 } 163 164 return true, unmarshaler.UnmarshalFlag(val) 165 } 166 } 167 168 if retval.Type().Kind() != reflect.Ptr && retval.CanAddr() { 169 return convertUnmarshal(val, retval.Addr()) 170 } 171 172 if retval.Type().Kind() == reflect.Interface && !retval.IsNil() { 173 return convertUnmarshal(val, retval.Elem()) 174 } 175 176 return false, nil 177} 178 179func convert(val string, retval reflect.Value, options multiTag) error { 180 if ok, err := convertUnmarshal(val, retval); ok { 181 return err 182 } 183 184 tp := retval.Type() 185 186 // Support for time.Duration 187 if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() { 188 parsed, err := time.ParseDuration(val) 189 190 if err != nil { 191 return err 192 } 193 194 retval.SetInt(int64(parsed)) 195 return nil 196 } 197 198 switch tp.Kind() { 199 case reflect.String: 200 retval.SetString(val) 201 case reflect.Bool: 202 if val == "" { 203 retval.SetBool(true) 204 } else { 205 b, err := strconv.ParseBool(val) 206 207 if err != nil { 208 return err 209 } 210 211 retval.SetBool(b) 212 } 213 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 214 base, err := getBase(options, 10) 215 216 if err != nil { 217 return err 218 } 219 220 parsed, err := strconv.ParseInt(val, base, tp.Bits()) 221 222 if err != nil { 223 return err 224 } 225 226 retval.SetInt(parsed) 227 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 228 base, err := getBase(options, 10) 229 230 if err != nil { 231 return err 232 } 233 234 parsed, err := strconv.ParseUint(val, base, tp.Bits()) 235 236 if err != nil { 237 return err 238 } 239 240 retval.SetUint(parsed) 241 case reflect.Float32, reflect.Float64: 242 parsed, err := strconv.ParseFloat(val, tp.Bits()) 243 244 if err != nil { 245 return err 246 } 247 248 retval.SetFloat(parsed) 249 case reflect.Slice: 250 elemtp := tp.Elem() 251 252 elemvalptr := reflect.New(elemtp) 253 elemval := reflect.Indirect(elemvalptr) 254 255 if err := convert(val, elemval, options); err != nil { 256 return err 257 } 258 259 retval.Set(reflect.Append(retval, elemval)) 260 case reflect.Map: 261 parts := strings.SplitN(val, ":", 2) 262 263 key := parts[0] 264 var value string 265 266 if len(parts) == 2 { 267 value = parts[1] 268 } 269 270 keytp := tp.Key() 271 keyval := reflect.New(keytp) 272 273 if err := convert(key, keyval, options); err != nil { 274 return err 275 } 276 277 valuetp := tp.Elem() 278 valueval := reflect.New(valuetp) 279 280 if err := convert(value, valueval, options); err != nil { 281 return err 282 } 283 284 if retval.IsNil() { 285 retval.Set(reflect.MakeMap(tp)) 286 } 287 288 retval.SetMapIndex(reflect.Indirect(keyval), reflect.Indirect(valueval)) 289 case reflect.Ptr: 290 if retval.IsNil() { 291 retval.Set(reflect.New(retval.Type().Elem())) 292 } 293 294 return convert(val, reflect.Indirect(retval), options) 295 case reflect.Interface: 296 if !retval.IsNil() { 297 return convert(val, retval.Elem(), options) 298 } 299 } 300 301 return nil 302} 303 304func isPrint(s string) bool { 305 for _, c := range s { 306 if !strconv.IsPrint(c) { 307 return false 308 } 309 } 310 311 return true 312} 313 314func quoteIfNeeded(s string) string { 315 if !isPrint(s) { 316 return strconv.Quote(s) 317 } 318 319 return s 320} 321 322func quoteIfNeededV(s []string) []string { 323 ret := make([]string, len(s)) 324 325 for i, v := range s { 326 ret[i] = quoteIfNeeded(v) 327 } 328 329 return ret 330} 331 332func quoteV(s []string) []string { 333 ret := make([]string, len(s)) 334 335 for i, v := range s { 336 ret[i] = strconv.Quote(v) 337 } 338 339 return ret 340} 341 342func unquoteIfPossible(s string) (string, error) { 343 if len(s) == 0 || s[0] != '"' { 344 return s, nil 345 } 346 347 return strconv.Unquote(s) 348} 349