1package structure
2
3// references: https://github.com/mitchellh/mapstructure
4
5import (
6	"fmt"
7	"reflect"
8	"strconv"
9	"strings"
10)
11
// Option holds the settings a Decoder is built from (see NewDecoder).
type Option struct {
	// TagName is the struct tag key read for field names and tag options.
	// NewDecoder falls back to "structure" when it is left empty.
	TagName string
	// WeaklyTypedInput turns on lenient conversions between basic types,
	// e.g. a numeric string decoded into an int field.
	WeaklyTypedInput bool
}
17
18// Decoder is the core of structure
19type Decoder struct {
20	option *Option
21}
22
23// NewDecoder return a Decoder by Option
24func NewDecoder(option Option) *Decoder {
25	if option.TagName == "" {
26		option.TagName = "structure"
27	}
28	return &Decoder{option: &option}
29}
30
31// Decode transform a map[string]interface{} to a struct
32func (d *Decoder) Decode(src map[string]interface{}, dst interface{}) error {
33	if reflect.TypeOf(dst).Kind() != reflect.Ptr {
34		return fmt.Errorf("Decode must recive a ptr struct")
35	}
36	t := reflect.TypeOf(dst).Elem()
37	v := reflect.ValueOf(dst).Elem()
38	for idx := 0; idx < v.NumField(); idx++ {
39		field := t.Field(idx)
40
41		tag := field.Tag.Get(d.option.TagName)
42		str := strings.SplitN(tag, ",", 2)
43		key := str[0]
44		omitempty := false
45		if len(str) > 1 {
46			omitempty = str[1] == "omitempty"
47		}
48
49		value, ok := src[key]
50		if !ok || value == nil {
51			if omitempty {
52				continue
53			}
54			return fmt.Errorf("key '%s' missing", key)
55		}
56
57		err := d.decode(key, value, v.Field(idx))
58		if err != nil {
59			return err
60		}
61	}
62	return nil
63}
64
65func (d *Decoder) decode(name string, data interface{}, val reflect.Value) error {
66	switch val.Kind() {
67	case reflect.Int:
68		return d.decodeInt(name, data, val)
69	case reflect.String:
70		return d.decodeString(name, data, val)
71	case reflect.Bool:
72		return d.decodeBool(name, data, val)
73	case reflect.Slice:
74		return d.decodeSlice(name, data, val)
75	case reflect.Map:
76		return d.decodeMap(name, data, val)
77	case reflect.Interface:
78		return d.setInterface(name, data, val)
79	case reflect.Struct:
80		return d.decodeStruct(name, data, val)
81	default:
82		return fmt.Errorf("type %s not support", val.Kind().String())
83	}
84}
85
86func (d *Decoder) decodeInt(name string, data interface{}, val reflect.Value) (err error) {
87	dataVal := reflect.ValueOf(data)
88	kind := dataVal.Kind()
89	switch {
90	case kind == reflect.Int:
91		val.SetInt(dataVal.Int())
92	case kind == reflect.String && d.option.WeaklyTypedInput:
93		var i int64
94		i, err = strconv.ParseInt(dataVal.String(), 0, val.Type().Bits())
95		if err == nil {
96			val.SetInt(i)
97		} else {
98			err = fmt.Errorf("cannot parse '%s' as int: %s", name, err)
99		}
100	default:
101		err = fmt.Errorf(
102			"'%s' expected type '%s', got unconvertible type '%s'",
103			name, val.Type(), dataVal.Type(),
104		)
105	}
106	return err
107}
108
109func (d *Decoder) decodeString(name string, data interface{}, val reflect.Value) (err error) {
110	dataVal := reflect.ValueOf(data)
111	kind := dataVal.Kind()
112	switch {
113	case kind == reflect.String:
114		val.SetString(dataVal.String())
115	case kind == reflect.Int && d.option.WeaklyTypedInput:
116		val.SetString(strconv.FormatInt(dataVal.Int(), 10))
117	default:
118		err = fmt.Errorf(
119			"'%s' expected type '%s', got unconvertible type '%s'",
120			name, val.Type(), dataVal.Type(),
121		)
122	}
123	return err
124}
125
126func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) (err error) {
127	dataVal := reflect.ValueOf(data)
128	kind := dataVal.Kind()
129	switch {
130	case kind == reflect.Bool:
131		val.SetBool(dataVal.Bool())
132	case kind == reflect.Int && d.option.WeaklyTypedInput:
133		val.SetBool(dataVal.Int() != 0)
134	default:
135		err = fmt.Errorf(
136			"'%s' expected type '%s', got unconvertible type '%s'",
137			name, val.Type(), dataVal.Type(),
138		)
139	}
140	return err
141}
142
143func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error {
144	dataVal := reflect.Indirect(reflect.ValueOf(data))
145	valType := val.Type()
146	valElemType := valType.Elem()
147
148	if dataVal.Kind() != reflect.Slice {
149		return fmt.Errorf("'%s' is not a slice", name)
150	}
151
152	valSlice := val
153	for i := 0; i < dataVal.Len(); i++ {
154		currentData := dataVal.Index(i).Interface()
155		for valSlice.Len() <= i {
156			valSlice = reflect.Append(valSlice, reflect.Zero(valElemType))
157		}
158		currentField := valSlice.Index(i)
159
160		fieldName := fmt.Sprintf("%s[%d]", name, i)
161		if err := d.decode(fieldName, currentData, currentField); err != nil {
162			return err
163		}
164	}
165
166	val.Set(valSlice)
167	return nil
168}
169
170func (d *Decoder) decodeMap(name string, data interface{}, val reflect.Value) error {
171	valType := val.Type()
172	valKeyType := valType.Key()
173	valElemType := valType.Elem()
174
175	valMap := val
176
177	if valMap.IsNil() {
178		mapType := reflect.MapOf(valKeyType, valElemType)
179		valMap = reflect.MakeMap(mapType)
180	}
181
182	dataVal := reflect.Indirect(reflect.ValueOf(data))
183	if dataVal.Kind() != reflect.Map {
184		return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
185	}
186
187	return d.decodeMapFromMap(name, dataVal, val, valMap)
188}
189
190func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
191	valType := val.Type()
192	valKeyType := valType.Key()
193	valElemType := valType.Elem()
194
195	errors := make([]string, 0)
196
197	if dataVal.Len() == 0 {
198		if dataVal.IsNil() {
199			if !val.IsNil() {
200				val.Set(dataVal)
201			}
202		} else {
203			val.Set(valMap)
204		}
205
206		return nil
207	}
208
209	for _, k := range dataVal.MapKeys() {
210		fieldName := fmt.Sprintf("%s[%s]", name, k)
211
212		currentKey := reflect.Indirect(reflect.New(valKeyType))
213		if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
214			errors = append(errors, err.Error())
215			continue
216		}
217
218		v := dataVal.MapIndex(k).Interface()
219		if v == nil {
220			errors = append(errors, fmt.Sprintf("filed %s invalid", fieldName))
221			continue
222		}
223
224		currentVal := reflect.Indirect(reflect.New(valElemType))
225		if err := d.decode(fieldName, v, currentVal); err != nil {
226			errors = append(errors, err.Error())
227			continue
228		}
229
230		valMap.SetMapIndex(currentKey, currentVal)
231	}
232
233	val.Set(valMap)
234
235	if len(errors) > 0 {
236		return fmt.Errorf(strings.Join(errors, ","))
237	}
238
239	return nil
240}
241
242func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error {
243	dataVal := reflect.Indirect(reflect.ValueOf(data))
244
245	// If the type of the value to write to and the data match directly,
246	// then we just set it directly instead of recursing into the structure.
247	if dataVal.Type() == val.Type() {
248		val.Set(dataVal)
249		return nil
250	}
251
252	dataValKind := dataVal.Kind()
253	switch dataValKind {
254	case reflect.Map:
255		return d.decodeStructFromMap(name, dataVal, val)
256	default:
257		return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
258	}
259}
260
261func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) error {
262	dataValType := dataVal.Type()
263	if kind := dataValType.Key().Kind(); kind != reflect.String && kind != reflect.Interface {
264		return fmt.Errorf(
265			"'%s' needs a map with string keys, has '%s' keys",
266			name, dataValType.Key().Kind())
267	}
268
269	dataValKeys := make(map[reflect.Value]struct{})
270	dataValKeysUnused := make(map[interface{}]struct{})
271	for _, dataValKey := range dataVal.MapKeys() {
272		dataValKeys[dataValKey] = struct{}{}
273		dataValKeysUnused[dataValKey.Interface()] = struct{}{}
274	}
275
276	errors := make([]string, 0)
277
278	// This slice will keep track of all the structs we'll be decoding.
279	// There can be more than one struct if there are embedded structs
280	// that are squashed.
281	structs := make([]reflect.Value, 1, 5)
282	structs[0] = val
283
284	// Compile the list of all the fields that we're going to be decoding
285	// from all the structs.
286	type field struct {
287		field reflect.StructField
288		val   reflect.Value
289	}
290	fields := []field{}
291	for len(structs) > 0 {
292		structVal := structs[0]
293		structs = structs[1:]
294
295		structType := structVal.Type()
296
297		for i := 0; i < structType.NumField(); i++ {
298			fieldType := structType.Field(i)
299			fieldKind := fieldType.Type.Kind()
300
301			// If "squash" is specified in the tag, we squash the field down.
302			squash := false
303			tagParts := strings.Split(fieldType.Tag.Get(d.option.TagName), ",")
304			for _, tag := range tagParts[1:] {
305				if tag == "squash" {
306					squash = true
307					break
308				}
309			}
310
311			if squash {
312				if fieldKind != reflect.Struct {
313					errors = append(errors,
314						fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldKind).Error())
315				} else {
316					structs = append(structs, structVal.FieldByName(fieldType.Name))
317				}
318				continue
319			}
320
321			// Normal struct field, store it away
322			fields = append(fields, field{fieldType, structVal.Field(i)})
323		}
324	}
325
326	// for fieldType, field := range fields {
327	for _, f := range fields {
328		field, fieldValue := f.field, f.val
329		fieldName := field.Name
330
331		tagValue := field.Tag.Get(d.option.TagName)
332		tagValue = strings.SplitN(tagValue, ",", 2)[0]
333		if tagValue != "" {
334			fieldName = tagValue
335		}
336
337		rawMapKey := reflect.ValueOf(fieldName)
338		rawMapVal := dataVal.MapIndex(rawMapKey)
339		if !rawMapVal.IsValid() {
340			// Do a slower search by iterating over each key and
341			// doing case-insensitive search.
342			for dataValKey := range dataValKeys {
343				mK, ok := dataValKey.Interface().(string)
344				if !ok {
345					// Not a string key
346					continue
347				}
348
349				if strings.EqualFold(mK, fieldName) {
350					rawMapKey = dataValKey
351					rawMapVal = dataVal.MapIndex(dataValKey)
352					break
353				}
354			}
355
356			if !rawMapVal.IsValid() {
357				// There was no matching key in the map for the value in
358				// the struct. Just ignore.
359				continue
360			}
361		}
362
363		// Delete the key we're using from the unused map so we stop tracking
364		delete(dataValKeysUnused, rawMapKey.Interface())
365
366		if !fieldValue.IsValid() {
367			// This should never happen
368			panic("field is not valid")
369		}
370
371		// If we can't set the field, then it is unexported or something,
372		// and we just continue onwards.
373		if !fieldValue.CanSet() {
374			continue
375		}
376
377		// If the name is empty string, then we're at the root, and we
378		// don't dot-join the fields.
379		if name != "" {
380			fieldName = fmt.Sprintf("%s.%s", name, fieldName)
381		}
382
383		if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
384			errors = append(errors, err.Error())
385		}
386	}
387
388	if len(errors) > 0 {
389		return fmt.Errorf(strings.Join(errors, ","))
390	}
391
392	return nil
393}
394
395func (d *Decoder) setInterface(name string, data interface{}, val reflect.Value) (err error) {
396	dataVal := reflect.ValueOf(data)
397	val.Set(dataVal)
398	return nil
399}
400