1package yaml
2
3import (
4	"errors"
5	"fmt"
6	"math"
7	"reflect"
8	"strings"
9
10	goyaml "gopkg.in/yaml.v3"
11
12	"github.com/sclevine/yj/order"
13)
14
15// NOTE: some parts copied from gopkg.in/yaml.v3
16
const (
	// 400,000 decode operations is ~500kb of dense object declarations, or
	// ~5kb of dense object declarations with 10000% alias expansion
	aliasRatioRangeLow = 400000

	// 4,000,000 decode operations is ~5MB of dense object declarations, or
	// ~4.5MB of dense object declarations with 10% alias expansion
	aliasRatioRangeHigh = 4000000

	// aliasRatioRange is the range over which we scale allowed alias ratios
	aliasRatioRange = float64(aliasRatioRangeHigh - aliasRatioRangeLow)

	// longTagPrefix is the long-form tag prefix that shortTag rewrites to "!!".
	longTagPrefix = "tag:yaml.org,2002:"
	// mergeTag is the short-form tag that isMerge accepts on a "<<" key.
	mergeTag = "!!merge"
)
32
// ErrNotMaps is the error value that merge panics with when the value of a
// merge key ("<<") is not a mapping, an alias to a mapping, or a sequence of
// those.
var ErrNotMaps = errors.New("map merge requires map or sequence of maps as the value")
34
// Decoder holds the options used when normalizing a YAML node tree.
type Decoder struct {
	// KeyMarshal renders a decoded mapping key as bytes when the key is
	// neither a string nor a fmt.Stringer; the bytes are used as the key
	// string. It is called without a nil check.
	KeyMarshal func(interface{}) ([]byte, error)

	// If set, NaN, Inf, etc. are replaced by the set values.
	// NaN, PosInf and NegInf apply to float scalars in general; the Key*
	// variants are substituted for them while a mapping key is being decoded.
	NaN, PosInf, NegInf          interface{}
	KeyNaN, KeyPosInf, KeyNegInf interface{}
}
42
43// Decode decodes a YAML node tree into an the normalized object format.
44func (d *Decoder) Decode(node *goyaml.Node) (normal interface{}, err error) {
45	defer catchFailure(&err)
46	dt := decodeTracker{Decoder: d}
47	return dt.normalize(node), nil
48}
49
// decodeTracker carries the per-call state of a single decode on top of the
// embedded Decoder options.
type decodeTracker struct {
	*Decoder
	// aliases is the set of alias nodes currently being expanded; it is used
	// to reject an alias that is reached again during its own expansion.
	aliases map[*goyaml.Node]struct{}
	// doc is the document node most recently passed to document.
	doc *goyaml.Node
	// decodeCount is the number of normalize calls made so far.
	decodeCount int
	// aliasCount is the number of normalize calls made while aliasDepth > 0.
	aliasCount int
	// aliasDepth is the current nesting depth of alias expansion.
	aliasDepth int
}
58
59func (d *decodeTracker) normalize(n *goyaml.Node) interface{} {
60	d.decodeCount++
61	if d.aliasDepth > 0 {
62		d.aliasCount++
63	}
64	if d.aliasCount > 100 && d.decodeCount > 1000 && float64(d.aliasCount)/float64(d.decodeCount) > allowedAliasRatio(d.decodeCount) {
65		panic(fmt.Errorf("document contains excessive aliasing"))
66	}
67
68	switch n.Kind {
69	case goyaml.DocumentNode:
70		return d.document(n)
71	case goyaml.AliasNode:
72		return d.alias(n)
73	case goyaml.ScalarNode:
74		var out interface{}
75		if err := n.Decode(&out); err != nil {
76			panic(fmt.Errorf("scalar decode error: %s", err))
77		}
78		switch out := out.(type) {
79		case float64:
80			return d.float(out)
81		}
82		return d.other(out)
83	case goyaml.MappingNode:
84		return d.mapping(n)
85	case goyaml.SequenceNode:
86		return d.sequence(n)
87	case 0:
88		if n.IsZero() {
89			return nil
90		}
91		fallthrough
92	default:
93		panic(fmt.Errorf("cannot decode node with unknown kind %d", n.Kind))
94	}
95}
96
97func (d *decodeTracker) other(in interface{}) interface{} {
98	switch reflect.ValueOf(in).Kind() {
99	case reflect.Map, reflect.Array, reflect.Slice, reflect.Float32:
100		panic(fmt.Errorf("unexpected type: %#v", in))
101	}
102	return in
103}
104
105func (d *decodeTracker) float(in float64) interface{} {
106	switch {
107	case d.NaN != nil && math.IsNaN(in):
108		return d.NaN
109	case d.PosInf != nil && math.IsInf(in, 1):
110		return d.PosInf
111	case d.NegInf != nil && math.IsInf(in, -1):
112		return d.NegInf
113	}
114	return in
115}
116
117func (d *decodeTracker) key(n *goyaml.Node) string {
118	// Decoder remains reentrant, but decodeTracker need not be
119	defer func(dec *Decoder) {
120		d.Decoder = dec
121	}(d.Decoder)
122	kdec := *d.Decoder
123	kdec.NaN = d.KeyNaN
124	kdec.PosInf = d.KeyPosInf
125	kdec.NegInf = d.KeyNegInf
126	d.Decoder = &kdec
127	switch key := d.normalize(n).(type) {
128	case string:
129		return key
130	case fmt.Stringer:
131		return key.String()
132	default:
133		out, err := d.KeyMarshal(key)
134		if err != nil {
135			panic(err)
136		}
137		return string(out)
138	}
139}
140
141func allowedAliasRatio(decodeCount int) float64 {
142	switch {
143	case decodeCount <= aliasRatioRangeLow:
144		// allow 99% to come from alias expansion for small-to-medium documents
145		return 0.99
146	case decodeCount >= aliasRatioRangeHigh:
147		// allow 10% to come from alias expansion for very large documents
148		return 0.10
149	default:
150		// scale smoothly from 99% down to 10% over the range.
151		// this maps to 396,000 - 400,000 allowed alias-driven decodes over the range.
152		// 400,000 decode operations is ~100MB of allocations in worst-case scenarios (single-item maps).
153		return 0.99 - 0.89*(float64(decodeCount-aliasRatioRangeLow)/aliasRatioRange)
154	}
155}
156
157func (d *decodeTracker) document(n *goyaml.Node) interface{} {
158	if len(n.Content) != 1 {
159		panic(fmt.Errorf("invalid document"))
160	}
161	d.doc = n
162	return d.normalize(n.Content[0])
163}
164
165func (d *decodeTracker) alias(n *goyaml.Node) interface{} {
166	if d.aliases == nil {
167		d.aliases = make(map[*goyaml.Node]struct{})
168	}
169	if _, ok := d.aliases[n]; ok {
170		// TODO this could actually be allowed in some circumstances.
171		panic(fmt.Errorf("anchor '%s' value contains itself", n.Value))
172	}
173	d.aliases[n] = struct{}{}
174	d.aliasDepth++
175	out := d.normalize(n.Alias)
176	d.aliasDepth--
177	delete(d.aliases, n)
178	return out
179}
180
181func (d *decodeTracker) sequence(n *goyaml.Node) []interface{} {
182	out := make([]interface{}, 0, len(n.Content))
183	for _, c := range n.Content {
184		out = append(out, d.normalize(c))
185	}
186	return out
187}
188
189func shortTag(tag string) string {
190	// TODO This can easily be made faster and produce less garbage.
191	if strings.HasPrefix(tag, longTagPrefix) {
192		return "!!" + tag[len(longTagPrefix):]
193	}
194	return tag
195}
196
197func isMerge(n *goyaml.Node) bool {
198	return n.Kind == goyaml.ScalarNode && n.Value == "<<" && (n.Tag == "" || n.Tag == "!" || shortTag(n.Tag) == mergeTag)
199}
200
201func (d *decodeTracker) mapping(n *goyaml.Node) order.MapSlice {
202	l := len(n.Content)
203	out := make(order.MapSlice, 0, l/2)
204
205	for i := 0; i < l; i += 2 {
206		if isMerge(n.Content[i]) {
207			out = d.merge(out, n.Content[i+1])
208			continue
209		}
210		out = append(out, order.MapItem{
211			Key: d.key(n.Content[i]),
212			Val: d.normalize(n.Content[i+1]),
213		})
214	}
215	return out
216}
217
218func (d *decodeTracker) merge(m order.MapSlice, n *goyaml.Node) order.MapSlice {
219	switch n.Kind {
220	case goyaml.AliasNode:
221		if n.Alias != nil && n.Alias.Kind != goyaml.MappingNode {
222			panic(ErrNotMaps)
223		}
224		fallthrough
225	case goyaml.MappingNode:
226		in, ok := d.normalize(n).(order.MapSlice)
227		if !ok {
228			panic(ErrNotMaps)
229		}
230		return m.Merge(in)
231	case goyaml.SequenceNode:
232		// Step backwards as earlier nodes take precedence.
233		for i := len(n.Content) - 1; i >= 0; i-- {
234			ni := n.Content[i]
235			if ni.Kind == goyaml.AliasNode {
236				if ni.Alias != nil && ni.Alias.Kind != goyaml.MappingNode {
237					panic(ErrNotMaps)
238				}
239			} else if ni.Kind != goyaml.MappingNode {
240				panic(ErrNotMaps)
241			}
242			in, ok := d.normalize(n).(order.MapSlice)
243			if !ok {
244				panic(ErrNotMaps)
245			}
246			m = m.Merge(in)
247		}
248		return m
249	default:
250		panic(ErrNotMaps)
251	}
252}
253