1package dynamodbattribute
2
3import (
4	"reflect"
5	"sort"
6	"strings"
7)
8
// field describes a single struct field that can be mapped to an attribute,
// possibly promoted from an embedded (anonymous) struct.
type field struct {
	// tag holds the options parsed from the field's struct tag
	// (see enumFields for which tag key is consulted).
	tag

	// Name is the tag-supplied name when one is present, otherwise the Go
	// field name. NameFromTag records which of the two was used; a
	// tag-supplied name wins ties in dominantField.
	Name        string
	NameFromTag bool

	// Index is the chain of field indexes leading from the root struct to
	// this field, one entry per level of embedding (see buildField).
	// Type is the field's type; enumFields replaces an unnamed pointer type
	// with its element type.
	Index []int
	Type  reflect.Type
}
18
19type cachedFields struct {
20	fields       []field
21	fieldsByName map[string]int
22}
23
24func (f *cachedFields) All() []field {
25	return f.fields
26}
27
28func (f *cachedFields) FieldByName(name string) (field, bool) {
29	if i, ok := f.fieldsByName[name]; ok {
30		return f.fields[i], ok
31	}
32	for _, f := range f.fields {
33		if strings.EqualFold(f.Name, name) {
34			return f, true
35		}
36	}
37	return field{}, false
38}
39
40func buildField(pIdx []int, i int, sf reflect.StructField, fieldTag tag) field {
41	f := field{
42		Name: sf.Name,
43		Type: sf.Type,
44		tag:  fieldTag,
45	}
46	if len(fieldTag.Name) != 0 {
47		f.NameFromTag = true
48		f.Name = fieldTag.Name
49	}
50
51	f.Index = make([]int, len(pIdx)+1)
52	copy(f.Index, pIdx)
53	f.Index[len(pIdx)] = i
54
55	return f
56}
57
58// unionStructFields returns a list of fields for the given type. Type info is cached
59// to avoid repeated calls into the reflect package
60func unionStructFields(t reflect.Type, opts MarshalOptions) *cachedFields {
61	if cached, ok := fieldCache.Load(t); ok {
62		return cached
63	}
64
65	f := enumFields(t, opts)
66	sort.Sort(fieldsByName(f))
67	f = visibleFields(f)
68
69	fs := &cachedFields{
70		fields:       f,
71		fieldsByName: make(map[string]int, len(f)),
72	}
73	for i, f := range fs.fields {
74		fs.fieldsByName[f.Name] = i
75	}
76
77	cached, _ := fieldCache.LoadOrStore(t, fs)
78	return cached
79}
80
81// enumFields will recursively iterate through a structure and its nested
82// anonymous fields.
83//
84// Based on the enoding/json struct field enumeration of the Go Stdlib
85// https://golang.org/src/encoding/json/encode.go typeField func.
86func enumFields(t reflect.Type, opts MarshalOptions) []field {
87	// Fields to explore
88	current := []field{}
89	next := []field{{Type: t}}
90
91	// count of queued names
92	count := map[reflect.Type]int{}
93	nextCount := map[reflect.Type]int{}
94
95	visited := map[reflect.Type]struct{}{}
96	fields := []field{}
97
98	for len(next) > 0 {
99		current, next = next, current[:0]
100		count, nextCount = nextCount, map[reflect.Type]int{}
101
102		for _, f := range current {
103			if _, ok := visited[f.Type]; ok {
104				continue
105			}
106			visited[f.Type] = struct{}{}
107
108			for i := 0; i < f.Type.NumField(); i++ {
109				sf := f.Type.Field(i)
110				if sf.PkgPath != "" && !sf.Anonymous {
111					// Ignore unexported and non-anonymous fields
112					// unexported but anonymous field may still be used if
113					// the type has exported nested fields
114					continue
115				}
116
117				fieldTag := tag{}
118				fieldTag.parseAVTag(sf.Tag)
119				// Because MarshalOptions.TagKey must be explicitly set, use it
120				// over JSON, which is enabled by default.
121				if opts.TagKey != "" && fieldTag == (tag{}) {
122					fieldTag.parseStructTag(opts.TagKey, sf.Tag)
123				} else if opts.SupportJSONTags && fieldTag == (tag{}) {
124					fieldTag.parseStructTag("json", sf.Tag)
125				}
126
127				if fieldTag.Ignore {
128					continue
129				}
130
131				ft := sf.Type
132				if ft.Name() == "" && ft.Kind() == reflect.Ptr {
133					ft = ft.Elem()
134				}
135
136				structField := buildField(f.Index, i, sf, fieldTag)
137				structField.Type = ft
138
139				if !sf.Anonymous || ft.Kind() != reflect.Struct {
140					fields = append(fields, structField)
141					if count[f.Type] > 1 {
142						// If there were multiple instances, add a second,
143						// so that the annihilation code will see a duplicate.
144						// It only cares about the distinction between 1 or 2,
145						// so don't bother generating any more copies.
146						fields = append(fields, structField)
147					}
148					continue
149				}
150
151				// Record new anon struct to explore next round
152				nextCount[ft]++
153				if nextCount[ft] == 1 {
154					next = append(next, structField)
155				}
156			}
157		}
158	}
159
160	return fields
161}
162
163// visibleFields will return a slice of fields which are visible based on
164// Go's standard visiblity rules with the exception of ties being broken
165// by depth and struct tag naming.
166//
167// Based on the enoding/json field filtering of the Go Stdlib
168// https://golang.org/src/encoding/json/encode.go typeField func.
169func visibleFields(fields []field) []field {
170	// Delete all fields that are hidden by the Go rules for embedded fields,
171	// except that fields with JSON tags are promoted.
172
173	// The fields are sorted in primary order of name, secondary order
174	// of field index length. Loop over names; for each name, delete
175	// hidden fields by choosing the one dominant field that survives.
176	out := fields[:0]
177	for advance, i := 0, 0; i < len(fields); i += advance {
178		// One iteration per name.
179		// Find the sequence of fields with the name of this first field.
180		fi := fields[i]
181		name := fi.Name
182		for advance = 1; i+advance < len(fields); advance++ {
183			fj := fields[i+advance]
184			if fj.Name != name {
185				break
186			}
187		}
188		if advance == 1 { // Only one field with this name
189			out = append(out, fi)
190			continue
191		}
192		dominant, ok := dominantField(fields[i : i+advance])
193		if ok {
194			out = append(out, dominant)
195		}
196	}
197
198	fields = out
199	sort.Sort(fieldsByIndex(fields))
200
201	return fields
202}
203
204// dominantField looks through the fields, all of which are known to
205// have the same name, to find the single field that dominates the
206// others using Go's embedding rules, modified by the presence of
207// JSON tags. If there are multiple top-level fields, the boolean
208// will be false: This condition is an error in Go and we skip all
209// the fields.
210//
211// Based on the enoding/json field filtering of the Go Stdlib
212// https://golang.org/src/encoding/json/encode.go dominantField func.
213func dominantField(fields []field) (field, bool) {
214	// The fields are sorted in increasing index-length order. The winner
215	// must therefore be one with the shortest index length. Drop all
216	// longer entries, which is easy: just truncate the slice.
217	length := len(fields[0].Index)
218	tagged := -1 // Index of first tagged field.
219	for i, f := range fields {
220		if len(f.Index) > length {
221			fields = fields[:i]
222			break
223		}
224		if f.NameFromTag {
225			if tagged >= 0 {
226				// Multiple tagged fields at the same level: conflict.
227				// Return no field.
228				return field{}, false
229			}
230			tagged = i
231		}
232	}
233	if tagged >= 0 {
234		return fields[tagged], true
235	}
236	// All remaining fields have the same length. If there's more than one,
237	// we have a conflict (two fields named "X" at the same level) and we
238	// return no field.
239	if len(fields) > 1 {
240		return field{}, false
241	}
242	return fields[0], true
243}
244
245// fieldsByName sorts field by name, breaking ties with depth,
246// then breaking ties with "name came from json tag", then
247// breaking ties with index sequence.
248//
249// Based on the enoding/json field filtering of the Go Stdlib
250// https://golang.org/src/encoding/json/encode.go fieldsByName type.
251type fieldsByName []field
252
253func (x fieldsByName) Len() int { return len(x) }
254
255func (x fieldsByName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
256
257func (x fieldsByName) Less(i, j int) bool {
258	if x[i].Name != x[j].Name {
259		return x[i].Name < x[j].Name
260	}
261	if len(x[i].Index) != len(x[j].Index) {
262		return len(x[i].Index) < len(x[j].Index)
263	}
264	if x[i].NameFromTag != x[j].NameFromTag {
265		return x[i].NameFromTag
266	}
267	return fieldsByIndex(x).Less(i, j)
268}
269
270// fieldsByIndex sorts field by index sequence.
271//
272// Based on the enoding/json field filtering of the Go Stdlib
273// https://golang.org/src/encoding/json/encode.go fieldsByIndex type.
274type fieldsByIndex []field
275
276func (x fieldsByIndex) Len() int { return len(x) }
277
278func (x fieldsByIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
279
280func (x fieldsByIndex) Less(i, j int) bool {
281	for k, xik := range x[i].Index {
282		if k >= len(x[j].Index) {
283			return false
284		}
285		if xik != x[j].Index[k] {
286			return xik < x[j].Index[k]
287		}
288	}
289	return len(x[i].Index) < len(x[j].Index)
290}
291