1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package yaml
6
7import (
8	"bytes"
9	"encoding"
10	"encoding/json"
11	"reflect"
12	"sort"
13	"strings"
14	"sync"
15	"unicode"
16	"unicode/utf8"
17)
18
// indirect walks down v, allocating pointers as needed, until it reaches a
// non-pointer value.
//
// If it encounters a pointer whose type implements json.Unmarshaler or
// encoding.TextUnmarshaler, indirect stops and returns that unmarshaler
// together with the zero reflect.Value. Otherwise both unmarshalers are nil
// and the value that was reached is returned.
//
// If decodingNull is true, indirect stops at the last pointer in the chain
// (provided it is settable) so that the caller can set it to nil.
//
// A pointer to an interface that holds that very pointer
// (var x interface{}; x = &x) is detected and the interface value is
// returned; following it would otherwise never terminate.
func indirect(v reflect.Value, decodingNull bool) (json.Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
	// If v is a named type and is addressable,
	// start with its address, so that if the type has pointer methods,
	// we find them.
	if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
		v = v.Addr()
	}
	for {
		// Load value from interface, but only if the result will be
		// usefully addressable.
		if v.Kind() == reflect.Interface && !v.IsNil() {
			e := v.Elem()
			if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) {
				v = e
				continue
			}
		}

		if v.Kind() != reflect.Ptr {
			break
		}

		if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() {
			break
		}

		// Prevent an infinite loop when v points to an interface that in
		// turn holds v itself: the interface branch above would hand the
		// same pointer straight back on the next iteration.
		if e := v.Elem(); e.Kind() == reflect.Interface && !e.IsNil() {
			if p := e.Elem(); p.Kind() == reflect.Ptr && p.Type() == v.Type() && p.Pointer() == v.Pointer() {
				v = e
				break
			}
		}

		if v.IsNil() {
			if v.CanSet() {
				v.Set(reflect.New(v.Type().Elem()))
			} else {
				// v cannot be updated in place; continue with a fresh
				// allocation that is not linked to the original value.
				v = reflect.New(v.Type().Elem())
			}
		}
		// Only pointer types with methods can satisfy either interface.
		if v.Type().NumMethod() > 0 {
			if u, ok := v.Interface().(json.Unmarshaler); ok {
				return u, nil, reflect.Value{}
			}
			if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
				return nil, u, reflect.Value{}
			}
		}
		v = v.Elem()
	}
	return nil, nil, v
}
67
// A field represents a single field found in a struct.
type field struct {
	// name is the key the field is known by: the name given in the struct
	// tag when there is a valid one, otherwise the Go field name.
	name      string
	nameBytes []byte                 // []byte(name)
	equalFold func(s, t []byte) bool // bytes.EqualFold or equivalent

	// tag reports whether name came from the struct tag rather than from
	// the Go field name.
	tag       bool
	// index is the sequence of field indexes leading from the top-level
	// struct, through any embedded structs, to this field.
	index     []int
	// typ is the field's type; for a field declared with an unnamed pointer
	// type it is the pointed-to type.
	typ       reflect.Type
	// omitEmpty is set when the tag options contain "omitempty".
	omitEmpty bool
	// quoted is set when the tag options contain "string".
	quoted    bool
}
80
81func fillField(f field) field {
82	f.nameBytes = []byte(f.name)
83	f.equalFold = foldFunc(f.nameBytes)
84	return f
85}
86
87// byName sorts field by name, breaking ties with depth,
88// then breaking ties with "name came from json tag", then
89// breaking ties with index sequence.
90type byName []field
91
92func (x byName) Len() int { return len(x) }
93
94func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
95
96func (x byName) Less(i, j int) bool {
97	if x[i].name != x[j].name {
98		return x[i].name < x[j].name
99	}
100	if len(x[i].index) != len(x[j].index) {
101		return len(x[i].index) < len(x[j].index)
102	}
103	if x[i].tag != x[j].tag {
104		return x[i].tag
105	}
106	return byIndex(x).Less(i, j)
107}
108
109// byIndex sorts field by index sequence.
110type byIndex []field
111
112func (x byIndex) Len() int { return len(x) }
113
114func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
115
116func (x byIndex) Less(i, j int) bool {
117	for k, xik := range x[i].index {
118		if k >= len(x[j].index) {
119			return false
120		}
121		if xik != x[j].index[k] {
122			return xik < x[j].index[k]
123		}
124	}
125	return len(x[i].index) < len(x[j].index)
126}
127
128// typeFields returns a list of fields that JSON should recognize for the given type.
129// The algorithm is breadth-first search over the set of structs to include - the top struct
130// and then any reachable anonymous structs.
131func typeFields(t reflect.Type) []field {
132	// Anonymous fields to explore at the current level and the next.
133	current := []field{}
134	next := []field{{typ: t}}
135
136	// Count of queued names for current level and the next.
137	count := map[reflect.Type]int{}
138	nextCount := map[reflect.Type]int{}
139
140	// Types already visited at an earlier level.
141	visited := map[reflect.Type]bool{}
142
143	// Fields found.
144	var fields []field
145
146	for len(next) > 0 {
147		current, next = next, current[:0]
148		count, nextCount = nextCount, map[reflect.Type]int{}
149
150		for _, f := range current {
151			if visited[f.typ] {
152				continue
153			}
154			visited[f.typ] = true
155
156			// Scan f.typ for fields to include.
157			for i := 0; i < f.typ.NumField(); i++ {
158				sf := f.typ.Field(i)
159				if sf.PkgPath != "" { // unexported
160					continue
161				}
162				tag := sf.Tag.Get("json")
163				if tag == "-" {
164					continue
165				}
166				name, opts := parseTag(tag)
167				if !isValidTag(name) {
168					name = ""
169				}
170				index := make([]int, len(f.index)+1)
171				copy(index, f.index)
172				index[len(f.index)] = i
173
174				ft := sf.Type
175				if ft.Name() == "" && ft.Kind() == reflect.Ptr {
176					// Follow pointer.
177					ft = ft.Elem()
178				}
179
180				// Record found field and index sequence.
181				if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
182					tagged := name != ""
183					if name == "" {
184						name = sf.Name
185					}
186					fields = append(fields, fillField(field{
187						name:      name,
188						tag:       tagged,
189						index:     index,
190						typ:       ft,
191						omitEmpty: opts.Contains("omitempty"),
192						quoted:    opts.Contains("string"),
193					}))
194					if count[f.typ] > 1 {
195						// If there were multiple instances, add a second,
196						// so that the annihilation code will see a duplicate.
197						// It only cares about the distinction between 1 or 2,
198						// so don't bother generating any more copies.
199						fields = append(fields, fields[len(fields)-1])
200					}
201					continue
202				}
203
204				// Record new anonymous struct to explore in next round.
205				nextCount[ft]++
206				if nextCount[ft] == 1 {
207					next = append(next, fillField(field{name: ft.Name(), index: index, typ: ft}))
208				}
209			}
210		}
211	}
212
213	sort.Sort(byName(fields))
214
215	// Delete all fields that are hidden by the Go rules for embedded fields,
216	// except that fields with JSON tags are promoted.
217
218	// The fields are sorted in primary order of name, secondary order
219	// of field index length. Loop over names; for each name, delete
220	// hidden fields by choosing the one dominant field that survives.
221	out := fields[:0]
222	for advance, i := 0, 0; i < len(fields); i += advance {
223		// One iteration per name.
224		// Find the sequence of fields with the name of this first field.
225		fi := fields[i]
226		name := fi.name
227		for advance = 1; i+advance < len(fields); advance++ {
228			fj := fields[i+advance]
229			if fj.name != name {
230				break
231			}
232		}
233		if advance == 1 { // Only one field with this name
234			out = append(out, fi)
235			continue
236		}
237		dominant, ok := dominantField(fields[i : i+advance])
238		if ok {
239			out = append(out, dominant)
240		}
241	}
242
243	fields = out
244	sort.Sort(byIndex(fields))
245
246	return fields
247}
248
249// dominantField looks through the fields, all of which are known to
250// have the same name, to find the single field that dominates the
251// others using Go's embedding rules, modified by the presence of
252// JSON tags. If there are multiple top-level fields, the boolean
253// will be false: This condition is an error in Go and we skip all
254// the fields.
255func dominantField(fields []field) (field, bool) {
256	// The fields are sorted in increasing index-length order. The winner
257	// must therefore be one with the shortest index length. Drop all
258	// longer entries, which is easy: just truncate the slice.
259	length := len(fields[0].index)
260	tagged := -1 // Index of first tagged field.
261	for i, f := range fields {
262		if len(f.index) > length {
263			fields = fields[:i]
264			break
265		}
266		if f.tag {
267			if tagged >= 0 {
268				// Multiple tagged fields at the same level: conflict.
269				// Return no field.
270				return field{}, false
271			}
272			tagged = i
273		}
274	}
275	if tagged >= 0 {
276		return fields[tagged], true
277	}
278	// All remaining fields have the same length. If there's more than one,
279	// we have a conflict (two fields named "X" at the same level) and we
280	// return no field.
281	if len(fields) > 1 {
282		return field{}, false
283	}
284	return fields[0], true
285}
286
// fieldCache holds the field lists already computed by typeFields, keyed by
// struct type. The embedded RWMutex guards m.
var fieldCache struct {
	sync.RWMutex
	m map[reflect.Type][]field // nil until cachedTypeFields stores its first entry
}
291
292// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
293func cachedTypeFields(t reflect.Type) []field {
294	fieldCache.RLock()
295	f := fieldCache.m[t]
296	fieldCache.RUnlock()
297	if f != nil {
298		return f
299	}
300
301	// Compute fields without lock.
302	// Might duplicate effort but won't hold other computations back.
303	f = typeFields(t)
304	if f == nil {
305		f = []field{}
306	}
307
308	fieldCache.Lock()
309	if fieldCache.m == nil {
310		fieldCache.m = map[reflect.Type][]field{}
311	}
312	fieldCache.m[t] = f
313	fieldCache.Unlock()
314	return f
315}
316
// isValidTag reports whether s may be used as a field name taken from a
// struct tag. The empty string is rejected. Every rune must be a Unicode
// letter, a Unicode digit, or one of a fixed set of punctuation characters.
func isValidTag(s string) bool {
	if len(s) == 0 {
		return false
	}
	// Punctuation permitted in a tag name. Backslash and quote characters
	// are reserved and therefore absent from this set.
	const allowedPunct = "!#$%&()*+-./:<=>?@[]^_{|}~ "
	for _, r := range s {
		if strings.ContainsRune(allowedPunct, r) {
			continue
		}
		if !(unicode.IsLetter(r) || unicode.IsDigit(r)) {
			return false
		}
	}
	return true
}
335
// Constants used by the ASCII case-folding comparisons below.
const (
	// caseMask has every bit set except 0x20; AND-ing a byte with it clears
	// the bit in which upper and lower case ASCII letters differ.
	caseMask     = ^byte(0x20) // Mask to ignore case in ASCII.
	// kelvin is U+212A, the Kelvin sign, which case-folds to 'k' and 'K'.
	kelvin       = '\u212a'
	// smallLongEss is U+017F, Latin small letter long s, which case-folds
	// to 's' and 'S'.
	smallLongEss = '\u017f'
)
341
342// foldFunc returns one of four different case folding equivalence
343// functions, from most general (and slow) to fastest:
344//
345// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8
346// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S')
347// 3) asciiEqualFold, no special, but includes non-letters (including _)
348// 4) simpleLetterEqualFold, no specials, no non-letters.
349//
350// The letters S and K are special because they map to 3 runes, not just 2:
351//  * S maps to s and to U+017F 'ſ' Latin small letter long s
352//  * k maps to K and to U+212A 'K' Kelvin sign
353// See http://play.golang.org/p/tTxjOc0OGo
354//
355// The returned function is specialized for matching against s and
356// should only be given s. It's not curried for performance reasons.
357func foldFunc(s []byte) func(s, t []byte) bool {
358	nonLetter := false
359	special := false // special letter
360	for _, b := range s {
361		if b >= utf8.RuneSelf {
362			return bytes.EqualFold
363		}
364		upper := b & caseMask
365		if upper < 'A' || upper > 'Z' {
366			nonLetter = true
367		} else if upper == 'K' || upper == 'S' {
368			// See above for why these letters are special.
369			special = true
370		}
371	}
372	if special {
373		return equalFoldRight
374	}
375	if nonLetter {
376		return asciiEqualFold
377	}
378	return simpleLetterEqualFold
379}
380
381// equalFoldRight is a specialization of bytes.EqualFold when s is
382// known to be all ASCII (including punctuation), but contains an 's',
383// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t.
384// See comments on foldFunc.
385func equalFoldRight(s, t []byte) bool {
386	for _, sb := range s {
387		if len(t) == 0 {
388			return false
389		}
390		tb := t[0]
391		if tb < utf8.RuneSelf {
392			if sb != tb {
393				sbUpper := sb & caseMask
394				if 'A' <= sbUpper && sbUpper <= 'Z' {
395					if sbUpper != tb&caseMask {
396						return false
397					}
398				} else {
399					return false
400				}
401			}
402			t = t[1:]
403			continue
404		}
405		// sb is ASCII and t is not. t must be either kelvin
406		// sign or long s; sb must be s, S, k, or K.
407		tr, size := utf8.DecodeRune(t)
408		switch sb {
409		case 's', 'S':
410			if tr != smallLongEss {
411				return false
412			}
413		case 'k', 'K':
414			if tr != kelvin {
415				return false
416			}
417		default:
418			return false
419		}
420		t = t[size:]
421
422	}
423	if len(t) > 0 {
424		return false
425	}
426	return true
427}
428
429// asciiEqualFold is a specialization of bytes.EqualFold for use when
430// s is all ASCII (but may contain non-letters) and contains no
431// special-folding letters.
432// See comments on foldFunc.
433func asciiEqualFold(s, t []byte) bool {
434	if len(s) != len(t) {
435		return false
436	}
437	for i, sb := range s {
438		tb := t[i]
439		if sb == tb {
440			continue
441		}
442		if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') {
443			if sb&caseMask != tb&caseMask {
444				return false
445			}
446		} else {
447			return false
448		}
449	}
450	return true
451}
452
453// simpleLetterEqualFold is a specialization of bytes.EqualFold for
454// use when s is all ASCII letters (no underscores, etc) and also
455// doesn't contain 'k', 'K', 's', or 'S'.
456// See comments on foldFunc.
457func simpleLetterEqualFold(s, t []byte) bool {
458	if len(s) != len(t) {
459		return false
460	}
461	for i, b := range s {
462		if b&caseMask != t[i]&caseMask {
463			return false
464		}
465	}
466	return true
467}
468
// tagOptions holds whatever follows the first comma in a struct field's
// "json" tag, without that comma. It is empty when the tag has no options.
type tagOptions string

// parseTag splits a struct field's json tag at its first comma, returning
// the name before it and the comma-separated options after it.
func parseTag(tag string) (string, tagOptions) {
	comma := strings.Index(tag, ",")
	if comma < 0 {
		return tag, tagOptions("")
	}
	return tag[:comma], tagOptions(tag[comma+1:])
}

// Contains reports whether optionName appears as a complete entry of the
// comma-separated option list: it must be delimited on each side by a comma
// or by the start or end of the list.
func (o tagOptions) Contains(optionName string) bool {
	rest := string(o)
	for len(rest) > 0 {
		head, tail := rest, ""
		if comma := strings.Index(rest, ","); comma >= 0 {
			head, tail = rest[:comma], rest[comma+1:]
		}
		if head == optionName {
			return true
		}
		rest = tail
	}
	return false
}
503