1package structs
2
3import (
4	"reflect"
5	"sort"
6	"strconv"
7
8	"github.com/spf13/cast"
9	"github.com/zeebo/errs"
10)
11
// Result contains information about the result of a Decode. Any of the maps may be nil when
// nothing was recorded in them; reading from a nil map is safe.
type Result struct {
	// Error combines every error encountered while decoding, or is nil if there were none.
	Error error

	// Used holds the paths of input values that were successfully stored into the output.
	Used map[string]struct{}

	// Missing holds the paths of input values for which walking the output found no
	// destination.
	Missing map[string]struct{}

	// Broken holds the paths of input values that could not be decoded, either because
	// walking the output failed or because the value could not be stored.
	Broken map[string]struct{}
}
19
20// DecodeOption controls the operation of a Decode.
21type DecodeOption interface {
22	private()
23
24	apply(*decodeState)
25}
26
27// decodeOptionFunc is an implementation of DecodeOption for a func
28type decodeOptionFunc func(*decodeState)
29
30func (o decodeOptionFunc) private()              {}
31func (o decodeOptionFunc) apply(ds *decodeState) { o(ds) }
32
33// Decode takes values out of input and stores them into output, allocating as necessary.
34func Decode(input map[string]interface{}, output interface{}, opts ...DecodeOption) Result {
35	var ds decodeState
36	for _, opt := range opts {
37		opt.apply(&ds)
38	}
39	ds.decode(input, reflect.ValueOf(output), "")
40	return ds.res
41}
42
43// decodeState keeps state across recursive calls to decode.
44type decodeState struct {
45	res Result
46}
47
48// decodeKeyValue decodes into output the value after walking through fields/indexing as described
49// by key. It returns true if anything was set. The base is the path the output is at with respect
50// to the top most decode.
51func (d *decodeState) decodeKeyValue(key string, value interface{}, output reflect.Value, base string) bool {
52	nextBase := dotJoin(base, key)
53
54	var rw reflectWalker
55	field, err := rw.Walk(output, key)
56	if err != nil {
57		d.res.Broken = gatherKeys(value, nextBase, d.res.Broken)
58		d.res.Error = errs.Combine(d.res.Error, err)
59		return false
60	}
61	if !field.IsValid() {
62		d.res.Missing = gatherKeys(value, nextBase, d.res.Missing)
63		return false
64	}
65
66	if d.decode(value, field, nextBase) {
67		rw.Commit()
68		return true
69	}
70	return false
71}
72
73// decode looks at the type of input and dispatches to helper routines to decode the input into
74// the output. It returns true if anything was set.
75func (d *decodeState) decode(input interface{}, output reflect.Value, base string) bool {
76	switch input := input.(type) {
77	case map[string]interface{}:
78		// Go through the keys in sorted order to avoid randomness
79		keys := make([]string, 0, len(input))
80		for key := range input {
81			keys = append(keys, key)
82		}
83		sort.Strings(keys)
84
85		any := false
86		for _, key := range keys {
87			any = d.decodeKeyValue(key, input[key], output, base) || any
88		}
89		return any
90
91	case map[interface{}]interface{}:
92		// Filter out the string keys and go through them in sorted order to avoid randomness
93		keys := make([]string, 0, len(input))
94		for key := range input {
95			if key, err := cast.ToStringE(key); err == nil {
96				keys = append(keys, key)
97			}
98		}
99		sort.Strings(keys)
100
101		any := false
102		for _, key := range keys {
103			any = d.decodeKeyValue(key, input[key], output, base) || any
104		}
105		return any
106
107	case []interface{}:
108		any := false
109		for key, value := range input {
110			any = d.decodeKeyValue(strconv.Itoa(key), value, output, base) || any
111		}
112		return any
113
114	default:
115		set, err := setValue(output, input)
116		if !set || err != nil {
117			d.res.Broken = gatherKeys(input, base, d.res.Broken)
118			d.res.Error = errs.Combine(d.res.Error, err)
119		} else if set {
120			if d.res.Used == nil {
121				d.res.Used = make(map[string]struct{})
122			}
123			d.res.Used[base] = struct{}{}
124		}
125		return set
126	}
127}
128