package yaml

import (
	"errors"
	"fmt"
	"math"
	"reflect"
	"strings"

	goyaml "gopkg.in/yaml.v3"

	"github.com/sclevine/yj/order"
)

// NOTE: some parts copied from gopkg.in/yaml.v3

const (
	// 400,000 decode operations is ~500kb of dense object declarations, or
	// ~5kb of dense object declarations with 10000% alias expansion
	aliasRatioRangeLow = 400000

	// 4,000,000 decode operations is ~5MB of dense object declarations, or
	// ~4.5MB of dense object declarations with 10% alias expansion
	aliasRatioRangeHigh = 4000000

	// aliasRatioRange is the range over which we scale allowed alias ratios
	aliasRatioRange = float64(aliasRatioRangeHigh - aliasRatioRangeLow)

	// longTagPrefix is the long-form tag prefix that shortTag rewrites to "!!".
	longTagPrefix = "tag:yaml.org,2002:"
	// mergeTag is the short-form tag that isMerge accepts on a "<<" key.
	mergeTag = "!!merge"
)

// ErrNotMaps is the error raised when the value of a merge key ("<<") is not
// a mapping, an alias to a mapping, or a sequence of those.
var ErrNotMaps = errors.New("map merge requires map or sequence of maps as the value")

// Decoder holds the options used to turn a YAML node tree into normalized
// values.
type Decoder struct {
	// KeyMarshal converts a decoded mapping key that is neither a string nor
	// a fmt.Stringer into the bytes used as its string form.
	KeyMarshal func(interface{}) ([]byte, error)

	// If set, NaN, Inf, etc. are replaced by the set values
	// The Key* variants are the replacements applied while decoding mapping
	// keys; the plain variants apply everywhere else.
	NaN, PosInf, NegInf          interface{}
	KeyNaN, KeyPosInf, KeyNegInf interface{}
}

// Decode decodes a YAML node tree into the normalized object format.
44func (d *Decoder) Decode(node *goyaml.Node) (normal interface{}, err error) { 45 defer catchFailure(&err) 46 dt := decodeTracker{Decoder: d} 47 return dt.normalize(node), nil 48} 49 50type decodeTracker struct { 51 *Decoder 52 aliases map[*goyaml.Node]struct{} 53 doc *goyaml.Node 54 decodeCount int 55 aliasCount int 56 aliasDepth int 57} 58 59func (d *decodeTracker) normalize(n *goyaml.Node) interface{} { 60 d.decodeCount++ 61 if d.aliasDepth > 0 { 62 d.aliasCount++ 63 } 64 if d.aliasCount > 100 && d.decodeCount > 1000 && float64(d.aliasCount)/float64(d.decodeCount) > allowedAliasRatio(d.decodeCount) { 65 panic(fmt.Errorf("document contains excessive aliasing")) 66 } 67 68 switch n.Kind { 69 case goyaml.DocumentNode: 70 return d.document(n) 71 case goyaml.AliasNode: 72 return d.alias(n) 73 case goyaml.ScalarNode: 74 var out interface{} 75 if err := n.Decode(&out); err != nil { 76 panic(fmt.Errorf("scalar decode error: %s", err)) 77 } 78 switch out := out.(type) { 79 case float64: 80 return d.float(out) 81 } 82 return d.other(out) 83 case goyaml.MappingNode: 84 return d.mapping(n) 85 case goyaml.SequenceNode: 86 return d.sequence(n) 87 case 0: 88 if n.IsZero() { 89 return nil 90 } 91 fallthrough 92 default: 93 panic(fmt.Errorf("cannot decode node with unknown kind %d", n.Kind)) 94 } 95} 96 97func (d *decodeTracker) other(in interface{}) interface{} { 98 switch reflect.ValueOf(in).Kind() { 99 case reflect.Map, reflect.Array, reflect.Slice, reflect.Float32: 100 panic(fmt.Errorf("unexpected type: %#v", in)) 101 } 102 return in 103} 104 105func (d *decodeTracker) float(in float64) interface{} { 106 switch { 107 case d.NaN != nil && math.IsNaN(in): 108 return d.NaN 109 case d.PosInf != nil && math.IsInf(in, 1): 110 return d.PosInf 111 case d.NegInf != nil && math.IsInf(in, -1): 112 return d.NegInf 113 } 114 return in 115} 116 117func (d *decodeTracker) key(n *goyaml.Node) string { 118 // Decoder remains reentrant, but decodeTracker need not be 119 defer func(dec 
*Decoder) { 120 d.Decoder = dec 121 }(d.Decoder) 122 kdec := *d.Decoder 123 kdec.NaN = d.KeyNaN 124 kdec.PosInf = d.KeyPosInf 125 kdec.NegInf = d.KeyNegInf 126 d.Decoder = &kdec 127 switch key := d.normalize(n).(type) { 128 case string: 129 return key 130 case fmt.Stringer: 131 return key.String() 132 default: 133 out, err := d.KeyMarshal(key) 134 if err != nil { 135 panic(err) 136 } 137 return string(out) 138 } 139} 140 141func allowedAliasRatio(decodeCount int) float64 { 142 switch { 143 case decodeCount <= aliasRatioRangeLow: 144 // allow 99% to come from alias expansion for small-to-medium documents 145 return 0.99 146 case decodeCount >= aliasRatioRangeHigh: 147 // allow 10% to come from alias expansion for very large documents 148 return 0.10 149 default: 150 // scale smoothly from 99% down to 10% over the range. 151 // this maps to 396,000 - 400,000 allowed alias-driven decodes over the range. 152 // 400,000 decode operations is ~100MB of allocations in worst-case scenarios (single-item maps). 153 return 0.99 - 0.89*(float64(decodeCount-aliasRatioRangeLow)/aliasRatioRange) 154 } 155} 156 157func (d *decodeTracker) document(n *goyaml.Node) interface{} { 158 if len(n.Content) != 1 { 159 panic(fmt.Errorf("invalid document")) 160 } 161 d.doc = n 162 return d.normalize(n.Content[0]) 163} 164 165func (d *decodeTracker) alias(n *goyaml.Node) interface{} { 166 if d.aliases == nil { 167 d.aliases = make(map[*goyaml.Node]struct{}) 168 } 169 if _, ok := d.aliases[n]; ok { 170 // TODO this could actually be allowed in some circumstances. 
171 panic(fmt.Errorf("anchor '%s' value contains itself", n.Value)) 172 } 173 d.aliases[n] = struct{}{} 174 d.aliasDepth++ 175 out := d.normalize(n.Alias) 176 d.aliasDepth-- 177 delete(d.aliases, n) 178 return out 179} 180 181func (d *decodeTracker) sequence(n *goyaml.Node) []interface{} { 182 out := make([]interface{}, 0, len(n.Content)) 183 for _, c := range n.Content { 184 out = append(out, d.normalize(c)) 185 } 186 return out 187} 188 189func shortTag(tag string) string { 190 // TODO This can easily be made faster and produce less garbage. 191 if strings.HasPrefix(tag, longTagPrefix) { 192 return "!!" + tag[len(longTagPrefix):] 193 } 194 return tag 195} 196 197func isMerge(n *goyaml.Node) bool { 198 return n.Kind == goyaml.ScalarNode && n.Value == "<<" && (n.Tag == "" || n.Tag == "!" || shortTag(n.Tag) == mergeTag) 199} 200 201func (d *decodeTracker) mapping(n *goyaml.Node) order.MapSlice { 202 l := len(n.Content) 203 out := make(order.MapSlice, 0, l/2) 204 205 for i := 0; i < l; i += 2 { 206 if isMerge(n.Content[i]) { 207 out = d.merge(out, n.Content[i+1]) 208 continue 209 } 210 out = append(out, order.MapItem{ 211 Key: d.key(n.Content[i]), 212 Val: d.normalize(n.Content[i+1]), 213 }) 214 } 215 return out 216} 217 218func (d *decodeTracker) merge(m order.MapSlice, n *goyaml.Node) order.MapSlice { 219 switch n.Kind { 220 case goyaml.AliasNode: 221 if n.Alias != nil && n.Alias.Kind != goyaml.MappingNode { 222 panic(ErrNotMaps) 223 } 224 fallthrough 225 case goyaml.MappingNode: 226 in, ok := d.normalize(n).(order.MapSlice) 227 if !ok { 228 panic(ErrNotMaps) 229 } 230 return m.Merge(in) 231 case goyaml.SequenceNode: 232 // Step backwards as earlier nodes take precedence. 
233 for i := len(n.Content) - 1; i >= 0; i-- { 234 ni := n.Content[i] 235 if ni.Kind == goyaml.AliasNode { 236 if ni.Alias != nil && ni.Alias.Kind != goyaml.MappingNode { 237 panic(ErrNotMaps) 238 } 239 } else if ni.Kind != goyaml.MappingNode { 240 panic(ErrNotMaps) 241 } 242 in, ok := d.normalize(n).(order.MapSlice) 243 if !ok { 244 panic(ErrNotMaps) 245 } 246 m = m.Merge(in) 247 } 248 return m 249 default: 250 panic(ErrNotMaps) 251 } 252} 253