1// Copyright 2016 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package fields provides a view of the fields of a struct that follows the Go
16// rules, amended to consider tags and case insensitivity.
17//
18// Usage
19//
20// First define a function that interprets tags:
21//
22//   func parseTag(st reflect.StructTag) (name string, keep bool, other interface{}, err error) { ... }
23//
24// The function's return values describe whether to ignore the field
25// completely or provide an alternate name, as well as other data from the
26// parse that is stored to avoid re-parsing.
27//
28// Then define a function to validate the type:
29//
30//   func validate(t reflect.Type) error { ... }
31//
32// Then, if necessary, define a function to specify leaf types - types
33// which should be considered one field and not be recursed into:
34//
35//   func isLeafType(t reflect.Type) bool { ... }
36//
// For example:
38//
39//   func isLeafType(t reflect.Type) bool {
40//      return t == reflect.TypeOf(time.Time{})
41//   }
42//
43// Next, construct a Cache, passing your functions. As its name suggests, a
44// Cache remembers validation and field information for a type, so subsequent
45// calls with the same type are very fast.
46//
47//    cache := fields.NewCache(parseTag, validate, isLeafType)
48//
49// To get the fields of a struct type as determined by the above rules, call
50// the Fields method:
51//
52//    fields, err := cache.Fields(reflect.TypeOf(MyStruct{}))
53//
54// The return value can be treated as a slice of Fields.
55//
56// Given a string, such as a key or column name obtained during unmarshalling,
57// call Match on the list of fields to find a field whose name is the best
58// match:
59//
60//   field := fields.Match(name)
61//
62// Match looks for an exact match first, then falls back to a case-insensitive
63// comparison.
64package fields
65
66import (
67	"bytes"
68	"errors"
69	"reflect"
70	"sort"
71	"strings"
72	"sync"
73)
74
// A Field records information about a struct field, as reported by
// Cache.Fields.
//
// Fields are meant to be built by this package: the unexported members are
// derived from Name and are used by List.Match and List.MatchBytes. A Field
// whose equalFold is nil makes MatchBytes panic if it is reached during a
// case-insensitive comparison.
type Field struct {
	Name        string       // effective field name: the tag name if one was given, else the Go field name
	NameFromTag bool         // did Name come from a tag?
	Type        reflect.Type // field type
	Index       []int        // index sequence, for reflect.Value.FieldByIndex
	ParsedTag   interface{}  // third return value of the parseTag function

	// nameBytes is Name as a byte slice; List.MatchBytes compares it exactly.
	nameBytes []byte
	// equalFold is the case-insensitive comparison List.MatchBytes falls back
	// to when no exact match exists. It is produced by foldFunc for this
	// particular name (presumably specialized to its characters; see foldFunc).
	equalFold func(s, t []byte) bool
}
86
// ParseTagFunc is a function that accepts a struct tag and returns four
// values: an alternative name for the field extracted from the tag (or "" to
// keep the Go field name), a boolean saying whether to keep the field or
// ignore it, additional data that is stored with the field information (as
// Field.ParsedTag) to avoid having to parse the tag again, and an error.
type ParseTagFunc func(reflect.StructTag) (name string, keep bool, other interface{}, err error)

// ValidateFunc is a function that accepts a struct type and returns an error
// if the type is invalid in any way. A Cache calls it before computing the
// type's fields, and caches the error it returns.
type ValidateFunc func(reflect.Type) error

// LeafTypesFunc is a function that accepts the type of a struct field and
// reports whether fields of that type are leaves: recorded as a single field
// and never recursed into, even when the field is an embedded struct. It is
// called with the type of every field that survives tag parsing, not only
// with struct types.
type LeafTypesFunc func(reflect.Type) bool
99
// A Cache records information about the fields of struct types.
//
// A Cache is safe for use by multiple goroutines.
//
// Construct a Cache with NewCache, which substitutes defaults for nil
// callbacks. The zero value has nil callbacks and panics when used.
type Cache struct {
	parseTag  ParseTagFunc
	validate  ValidateFunc
	leafTypes LeafTypesFunc
	cache     sync.Map // from reflect.Type to cacheValue
}
109
110// NewCache constructs a Cache.
111//
112// Its first argument should be a function that accepts
113// a struct tag and returns four values: an alternative name for the field
114// extracted from the tag, a boolean saying whether to keep the field or ignore
115// it, additional data that is stored with the field information to avoid
116// having to parse the tag again, and an error.
117//
118// Its second argument should be a function that accepts a reflect.Type and
119// returns an error if the struct type is invalid in any way. For example, it
120// may check that all of the struct field tags are valid, or that all fields
121// are of an appropriate type.
122func NewCache(parseTag ParseTagFunc, validate ValidateFunc, leafTypes LeafTypesFunc) *Cache {
123	if parseTag == nil {
124		parseTag = func(reflect.StructTag) (string, bool, interface{}, error) {
125			return "", true, nil, nil
126		}
127	}
128	if validate == nil {
129		validate = func(reflect.Type) error {
130			return nil
131		}
132	}
133	if leafTypes == nil {
134		leafTypes = func(reflect.Type) bool {
135			return false
136		}
137	}
138
139	return &Cache{
140		parseTag:  parseTag,
141		validate:  validate,
142		leafTypes: leafTypes,
143	}
144}
145
// A fieldScan is an item on the breadth-first work list used by listFields:
// a struct type whose fields are still to be scanned, together with the index
// sequence leading from the top-level struct to the embedded field of that
// type.
type fieldScan struct {
	typ   reflect.Type // struct type to scan
	index []int        // index path from the root struct; nil for the root itself
}
151
152// Fields returns all the exported fields of t, which must be a struct type. It
153// follows the standard Go rules for embedded fields, modified by the presence
154// of tags. The result is sorted lexicographically by index.
155//
156// These rules apply in the absence of tags:
157// Anonymous struct fields are treated as if their inner exported fields were
158// fields in the outer struct (embedding). The result includes all fields that
159// aren't shadowed by fields at higher level of embedding. If more than one
160// field with the same name exists at the same level of embedding, it is
161// excluded. An anonymous field that is not of struct type is treated as having
162// its type as its name.
163//
164// Tags modify these rules as follows:
165// A field's tag is used as its name.
166// An anonymous struct field with a name given in its tag is treated as
167// a field having that name, rather than an embedded struct (the struct's
168// fields will not be returned).
169// If more than one field with the same name exists at the same level of embedding,
170// but exactly one of them is tagged, then the tagged field is reported and the others
171// are ignored.
172func (c *Cache) Fields(t reflect.Type) (List, error) {
173	if t.Kind() != reflect.Struct {
174		panic("fields: Fields of non-struct type")
175	}
176	return c.cachedTypeFields(t)
177}
178
// A List is a list of Fields. Lists returned by Cache.Fields are sorted by
// index sequence.
type List []Field
181
182// Match returns the field in the list whose name best matches the supplied
183// name, nor nil if no field does. If there is a field with the exact name, it
184// is returned. Otherwise the first field (sorted by index) whose name matches
185// case-insensitively is returned.
186func (l List) Match(name string) *Field {
187	return l.MatchBytes([]byte(name))
188}
189
190// MatchBytes is identical to Match, except that the argument is a byte slice.
191func (l List) MatchBytes(name []byte) *Field {
192	var f *Field
193	for i := range l {
194		ff := &l[i]
195		if bytes.Equal(ff.nameBytes, name) {
196			return ff
197		}
198		if f == nil && ff.equalFold(ff.nameBytes, name) {
199			f = ff
200		}
201	}
202	return f
203}
204
// cacheValue is the memoized outcome of computing the fields of one struct
// type. It is stored in Cache.cache, keyed by that reflect.Type. When err is
// non-nil (from validation or from listing the fields), fields is nil.
// NOTE: it is constructed positionally, so field order matters.
type cacheValue struct {
	fields List
	err    error
}
209
210// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
211// This code has been copied and modified from
212// https://go.googlesource.com/go/+/go1.7.3/src/encoding/json/encode.go.
213func (c *Cache) cachedTypeFields(t reflect.Type) (List, error) {
214	var cv cacheValue
215	x, ok := c.cache.Load(t)
216	if ok {
217		cv = x.(cacheValue)
218	} else {
219		if err := c.validate(t); err != nil {
220			cv = cacheValue{nil, err}
221		} else {
222			f, err := c.typeFields(t)
223			cv = cacheValue{List(f), err}
224		}
225		c.cache.Store(t, cv)
226	}
227	return cv.fields, cv.err
228}
229
230func (c *Cache) typeFields(t reflect.Type) ([]Field, error) {
231	fields, err := c.listFields(t)
232	if err != nil {
233		return nil, err
234	}
235	sort.Sort(byName(fields))
236	// Delete all fields that are hidden by the Go rules for embedded fields.
237
238	// The fields are sorted in primary order of name, secondary order of field
239	// index length. So the first field with a given name is the dominant one.
240	var out []Field
241	for advance, i := 0, 0; i < len(fields); i += advance {
242		// One iteration per name.
243		// Find the sequence of fields with the name of this first field.
244		fi := fields[i]
245		name := fi.Name
246		for advance = 1; i+advance < len(fields); advance++ {
247			fj := fields[i+advance]
248			if fj.Name != name {
249				break
250			}
251		}
252		// Find the dominant field, if any, out of all fields that have the same name.
253		dominant, ok := dominantField(fields[i : i+advance])
254		if ok {
255			out = append(out, dominant)
256		}
257	}
258	sort.Sort(byIndex(out))
259	return out, nil
260}
261
262func (c *Cache) listFields(t reflect.Type) ([]Field, error) {
263	// This uses the same condition that the Go language does: there must be a unique instance
264	// of the match at a given depth level. If there are multiple instances of a match at the
265	// same depth, they annihilate each other and inhibit any possible match at a lower level.
266	// The algorithm is breadth first search, one depth level at a time.
267
268	// The current and next slices are work queues:
269	// current lists the fields to visit on this depth level,
270	// and next lists the fields on the next lower level.
271	current := []fieldScan{}
272	next := []fieldScan{{typ: t}}
273
274	// nextCount records the number of times an embedded type has been
275	// encountered and considered for queueing in the 'next' slice.
276	// We only queue the first one, but we increment the count on each.
277	// If a struct type T can be reached more than once at a given depth level,
278	// then it annihilates itself and need not be considered at all when we
279	// process that next depth level.
280	var nextCount map[reflect.Type]int
281
282	// visited records the structs that have been considered already.
283	// Embedded pointer fields can create cycles in the graph of
284	// reachable embedded types; visited avoids following those cycles.
285	// It also avoids duplicated effort: if we didn't find the field in an
286	// embedded type T at level 2, we won't find it in one at level 4 either.
287	visited := map[reflect.Type]bool{}
288
289	var fields []Field // Fields found.
290
291	for len(next) > 0 {
292		current, next = next, current[:0]
293		count := nextCount
294		nextCount = nil
295
296		// Process all the fields at this depth, now listed in 'current'.
297		// The loop queues embedded fields found in 'next', for processing during the next
298		// iteration. The multiplicity of the 'current' field counts is recorded
299		// in 'count'; the multiplicity of the 'next' field counts is recorded in 'nextCount'.
300		for _, scan := range current {
301			t := scan.typ
302			if visited[t] {
303				// We've looked through this type before, at a higher level.
304				// That higher level would shadow the lower level we're now at,
305				// so this one can't be useful to us. Ignore it.
306				continue
307			}
308			visited[t] = true
309			for i := 0; i < t.NumField(); i++ {
310				f := t.Field(i)
311
312				exported := (f.PkgPath == "")
313
314				// If a named field is unexported, ignore it. An anonymous
315				// unexported field is processed, because it may contain
316				// exported fields, which are visible.
317				if !exported && !f.Anonymous {
318					continue
319				}
320
321				// Examine the tag.
322				tagName, keep, other, err := c.parseTag(f.Tag)
323				if err != nil {
324					return nil, err
325				}
326				if !keep {
327					continue
328				}
329				if c.leafTypes(f.Type) {
330					fields = append(fields, newField(f, tagName, other, scan.index, i))
331					continue
332				}
333
334				var ntyp reflect.Type
335				if f.Anonymous {
336					// Anonymous field of type T or *T.
337					ntyp = f.Type
338					if ntyp.Kind() == reflect.Ptr {
339						ntyp = ntyp.Elem()
340					}
341				}
342
343				// Record fields with a tag name, non-anonymous fields, or
344				// anonymous non-struct fields.
345				if tagName != "" || ntyp == nil || ntyp.Kind() != reflect.Struct {
346					if !exported {
347						continue
348					}
349					fields = append(fields, newField(f, tagName, other, scan.index, i))
350					if count[t] > 1 {
351						// If there were multiple instances, add a second,
352						// so that the annihilation code will see a duplicate.
353						fields = append(fields, fields[len(fields)-1])
354					}
355					continue
356				}
357
358				// Queue embedded struct fields for processing with next level,
359				// but only if the embedded types haven't already been queued.
360				if nextCount[ntyp] > 0 {
361					nextCount[ntyp] = 2 // exact multiple doesn't matter
362					continue
363				}
364				if nextCount == nil {
365					nextCount = map[reflect.Type]int{}
366				}
367				nextCount[ntyp] = 1
368				if count[t] > 1 {
369					nextCount[ntyp] = 2 // exact multiple doesn't matter
370				}
371				var index []int
372				index = append(index, scan.index...)
373				index = append(index, i)
374				next = append(next, fieldScan{ntyp, index})
375			}
376		}
377	}
378	return fields, nil
379}
380
381func newField(f reflect.StructField, tagName string, other interface{}, index []int, i int) Field {
382	name := tagName
383	if name == "" {
384		name = f.Name
385	}
386	sf := Field{
387		Name:        name,
388		NameFromTag: tagName != "",
389		Type:        f.Type,
390		ParsedTag:   other,
391		nameBytes:   []byte(name),
392	}
393	sf.equalFold = foldFunc(sf.nameBytes)
394	sf.Index = append(sf.Index, index...)
395	sf.Index = append(sf.Index, i)
396	return sf
397}
398
399// byName sorts fields using the following criteria, in order:
400// 1. name
401// 2. embedding depth
402// 3. tag presence (preferring a tagged field)
403// 4. index sequence.
404type byName []Field
405
406func (x byName) Len() int { return len(x) }
407
408func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
409
410func (x byName) Less(i, j int) bool {
411	if x[i].Name != x[j].Name {
412		return x[i].Name < x[j].Name
413	}
414	if len(x[i].Index) != len(x[j].Index) {
415		return len(x[i].Index) < len(x[j].Index)
416	}
417	if x[i].NameFromTag != x[j].NameFromTag {
418		return x[i].NameFromTag
419	}
420	return byIndex(x).Less(i, j)
421}
422
423// byIndex sorts field by index sequence.
424type byIndex []Field
425
426func (x byIndex) Len() int { return len(x) }
427
428func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
429
430func (x byIndex) Less(i, j int) bool {
431	xi := x[i].Index
432	xj := x[j].Index
433	ln := len(xi)
434	if l := len(xj); l < ln {
435		ln = l
436	}
437	for k := 0; k < ln; k++ {
438		if xi[k] != xj[k] {
439			return xi[k] < xj[k]
440		}
441	}
442	return len(xi) < len(xj)
443}
444
445// dominantField looks through the fields, all of which are known to have the
446// same name, to find the single field that dominates the others using Go's
447// embedding rules, modified by the presence of tags. If there are multiple
448// top-level fields, the boolean will be false: This condition is an error in
449// Go and we skip all the fields.
450func dominantField(fs []Field) (Field, bool) {
451	// The fields are sorted in increasing index-length order, then by presence of tag.
452	// That means that the first field is the dominant one. We need only check
453	// for error cases: two fields at top level, either both tagged or neither tagged.
454	if len(fs) > 1 && len(fs[0].Index) == len(fs[1].Index) && fs[0].NameFromTag == fs[1].NameFromTag {
455		return Field{}, false
456	}
457	return fs[0], true
458}
459
// ParseStandardTag extracts the sub-tag named by key, then parses it using the
// de facto standard format introduced in encoding/json:
//
//	"-"                    means "ignore this field": keep is false. It must
//	                       occur by itself; "-" followed by options makes
//	                       ParseStandardTag return an error, whereas
//	                       encoding/json would treat it as a field named "-".
//	"<name>"               provides an alternative name for the field.
//	"<name>,opt1,opt2,..." specifies options after the name.
//
// The options are returned as a []string, which is nil when there are none.
// An absent or empty sub-tag yields name "", keep true, and no options, and
// callers then use the Go field name.
func ParseStandardTag(key string, t reflect.StructTag) (name string, keep bool, options []string, err error) {
	parts := strings.Split(t.Get(key), ",")
	if parts[0] == "-" {
		if len(parts) > 1 {
			return "", false, nil, errors.New(`"-" field tag with options`)
		}
		return "", false, nil, nil
	}
	if len(parts) > 1 {
		options = parts[1:]
	}
	return parts[0], true, options, nil
}
481