1// Copyright 2012 Jesse van den Kieboom. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package flags
6
7import (
8	"fmt"
9	"reflect"
10	"strconv"
11	"strings"
12	"time"
13)
14
// Marshaler is implemented by flag value types that know how to render
// themselves as the string representation of a flag.
type Marshaler interface {
	// MarshalFlag returns the string representation of the flag value,
	// together with an error describing why it could not be produced, if any.
	MarshalFlag() (string, error)
}
21
// Unmarshaler is implemented by flag value types that know how to populate
// themselves from a flag argument. The argument is handed over exactly as it
// was given on the command line.
type Unmarshaler interface {
	// UnmarshalFlag parses value and stores the result in the receiver,
	// which is why implementations need a pointer receiver.
	UnmarshalFlag(value string) error
}
30
31func getBase(options multiTag, base int) (int, error) {
32	sbase := options.Get("base")
33
34	var err error
35	var ivbase int64
36
37	if sbase != "" {
38		ivbase, err = strconv.ParseInt(sbase, 10, 32)
39		base = int(ivbase)
40	}
41
42	return base, err
43}
44
45func convertMarshal(val reflect.Value) (bool, string, error) {
46	// Check first for the Marshaler interface
47	if val.Type().NumMethod() > 0 && val.CanInterface() {
48		if marshaler, ok := val.Interface().(Marshaler); ok {
49			ret, err := marshaler.MarshalFlag()
50			return true, ret, err
51		}
52	}
53
54	return false, "", nil
55}
56
57func convertToString(val reflect.Value, options multiTag) (string, error) {
58	if ok, ret, err := convertMarshal(val); ok {
59		return ret, err
60	}
61
62	tp := val.Type()
63
64	// Support for time.Duration
65	if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() {
66		stringer := val.Interface().(fmt.Stringer)
67		return stringer.String(), nil
68	}
69
70	switch tp.Kind() {
71	case reflect.String:
72		return val.String(), nil
73	case reflect.Bool:
74		if val.Bool() {
75			return "true", nil
76		}
77
78		return "false", nil
79	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
80		base, err := getBase(options, 10)
81
82		if err != nil {
83			return "", err
84		}
85
86		return strconv.FormatInt(val.Int(), base), nil
87	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
88		base, err := getBase(options, 10)
89
90		if err != nil {
91			return "", err
92		}
93
94		return strconv.FormatUint(val.Uint(), base), nil
95	case reflect.Float32, reflect.Float64:
96		return strconv.FormatFloat(val.Float(), 'g', -1, tp.Bits()), nil
97	case reflect.Slice:
98		if val.Len() == 0 {
99			return "", nil
100		}
101
102		ret := "["
103
104		for i := 0; i < val.Len(); i++ {
105			if i != 0 {
106				ret += ", "
107			}
108
109			item, err := convertToString(val.Index(i), options)
110
111			if err != nil {
112				return "", err
113			}
114
115			ret += item
116		}
117
118		return ret + "]", nil
119	case reflect.Map:
120		ret := "{"
121
122		for i, key := range val.MapKeys() {
123			if i != 0 {
124				ret += ", "
125			}
126
127			keyitem, err := convertToString(key, options)
128
129			if err != nil {
130				return "", err
131			}
132
133			item, err := convertToString(val.MapIndex(key), options)
134
135			if err != nil {
136				return "", err
137			}
138
139			ret += keyitem + ":" + item
140		}
141
142		return ret + "}", nil
143	case reflect.Ptr:
144		return convertToString(reflect.Indirect(val), options)
145	case reflect.Interface:
146		if !val.IsNil() {
147			return convertToString(val.Elem(), options)
148		}
149	}
150
151	return "", nil
152}
153
154func convertUnmarshal(val string, retval reflect.Value) (bool, error) {
155	if retval.Type().NumMethod() > 0 && retval.CanInterface() {
156		if unmarshaler, ok := retval.Interface().(Unmarshaler); ok {
157			if retval.IsNil() {
158				retval.Set(reflect.New(retval.Type().Elem()))
159
160				// Re-assign from the new value
161				unmarshaler = retval.Interface().(Unmarshaler)
162			}
163
164			return true, unmarshaler.UnmarshalFlag(val)
165		}
166	}
167
168	if retval.Type().Kind() != reflect.Ptr && retval.CanAddr() {
169		return convertUnmarshal(val, retval.Addr())
170	}
171
172	if retval.Type().Kind() == reflect.Interface && !retval.IsNil() {
173		return convertUnmarshal(val, retval.Elem())
174	}
175
176	return false, nil
177}
178
179func convert(val string, retval reflect.Value, options multiTag) error {
180	if ok, err := convertUnmarshal(val, retval); ok {
181		return err
182	}
183
184	tp := retval.Type()
185
186	// Support for time.Duration
187	if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() {
188		parsed, err := time.ParseDuration(val)
189
190		if err != nil {
191			return err
192		}
193
194		retval.SetInt(int64(parsed))
195		return nil
196	}
197
198	switch tp.Kind() {
199	case reflect.String:
200		retval.SetString(val)
201	case reflect.Bool:
202		if val == "" {
203			retval.SetBool(true)
204		} else {
205			b, err := strconv.ParseBool(val)
206
207			if err != nil {
208				return err
209			}
210
211			retval.SetBool(b)
212		}
213	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
214		base, err := getBase(options, 10)
215
216		if err != nil {
217			return err
218		}
219
220		parsed, err := strconv.ParseInt(val, base, tp.Bits())
221
222		if err != nil {
223			return err
224		}
225
226		retval.SetInt(parsed)
227	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
228		base, err := getBase(options, 10)
229
230		if err != nil {
231			return err
232		}
233
234		parsed, err := strconv.ParseUint(val, base, tp.Bits())
235
236		if err != nil {
237			return err
238		}
239
240		retval.SetUint(parsed)
241	case reflect.Float32, reflect.Float64:
242		parsed, err := strconv.ParseFloat(val, tp.Bits())
243
244		if err != nil {
245			return err
246		}
247
248		retval.SetFloat(parsed)
249	case reflect.Slice:
250		elemtp := tp.Elem()
251
252		elemvalptr := reflect.New(elemtp)
253		elemval := reflect.Indirect(elemvalptr)
254
255		if err := convert(val, elemval, options); err != nil {
256			return err
257		}
258
259		retval.Set(reflect.Append(retval, elemval))
260	case reflect.Map:
261		parts := strings.SplitN(val, ":", 2)
262
263		key := parts[0]
264		var value string
265
266		if len(parts) == 2 {
267			value = parts[1]
268		}
269
270		keytp := tp.Key()
271		keyval := reflect.New(keytp)
272
273		if err := convert(key, keyval, options); err != nil {
274			return err
275		}
276
277		valuetp := tp.Elem()
278		valueval := reflect.New(valuetp)
279
280		if err := convert(value, valueval, options); err != nil {
281			return err
282		}
283
284		if retval.IsNil() {
285			retval.Set(reflect.MakeMap(tp))
286		}
287
288		retval.SetMapIndex(reflect.Indirect(keyval), reflect.Indirect(valueval))
289	case reflect.Ptr:
290		if retval.IsNil() {
291			retval.Set(reflect.New(retval.Type().Elem()))
292		}
293
294		return convert(val, reflect.Indirect(retval), options)
295	case reflect.Interface:
296		if !retval.IsNil() {
297			return convert(val, retval.Elem(), options)
298		}
299	}
300
301	return nil
302}
303
// isPrint reports whether every rune of s is printable according to
// strconv.IsPrint. The empty string is considered printable.
func isPrint(s string) bool {
	printable := true

	for _, r := range s {
		if printable = strconv.IsPrint(r); !printable {
			break
		}
	}

	return printable
}
313
314func quoteIfNeeded(s string) string {
315	if !isPrint(s) {
316		return strconv.Quote(s)
317	}
318
319	return s
320}
321
322func quoteIfNeededV(s []string) []string {
323	ret := make([]string, len(s))
324
325	for i, v := range s {
326		ret[i] = quoteIfNeeded(v)
327	}
328
329	return ret
330}
331
// quoteV returns a new slice holding the strconv.Quote form of every
// element of s, in the same order. The input is not modified.
func quoteV(s []string) []string {
	quoted := make([]string, 0, len(s))

	for _, item := range s {
		quoted = append(quoted, strconv.Quote(item))
	}

	return quoted
}
341
// unquoteIfPossible interprets s as a Go quoted string when it starts with
// a double quote, returning the result of strconv.Unquote (including its
// error for malformed input). Any other string, the empty string included,
// is returned unchanged with a nil error.
func unquoteIfPossible(s string) (string, error) {
	if len(s) > 0 && s[0] == '"' {
		return strconv.Unquote(s)
	}

	return s, nil
}
349