1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package impl 6 7import ( 8 "math/bits" 9 10 "google.golang.org/protobuf/encoding/protowire" 11 "google.golang.org/protobuf/internal/errors" 12 "google.golang.org/protobuf/internal/flags" 13 "google.golang.org/protobuf/proto" 14 "google.golang.org/protobuf/reflect/protoreflect" 15 preg "google.golang.org/protobuf/reflect/protoregistry" 16 "google.golang.org/protobuf/runtime/protoiface" 17 piface "google.golang.org/protobuf/runtime/protoiface" 18) 19 20type unmarshalOptions struct { 21 flags protoiface.UnmarshalInputFlags 22 resolver interface { 23 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) 24 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) 25 } 26} 27 28func (o unmarshalOptions) Options() proto.UnmarshalOptions { 29 return proto.UnmarshalOptions{ 30 Merge: true, 31 AllowPartial: true, 32 DiscardUnknown: o.DiscardUnknown(), 33 Resolver: o.resolver, 34 } 35} 36 37func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 } 38 39func (o unmarshalOptions) IsDefault() bool { 40 return o.flags == 0 && o.resolver == preg.GlobalTypes 41} 42 43var lazyUnmarshalOptions = unmarshalOptions{ 44 resolver: preg.GlobalTypes, 45} 46 47type unmarshalOutput struct { 48 n int // number of bytes consumed 49 initialized bool 50} 51 52// unmarshal is protoreflect.Methods.Unmarshal. 
//
// It decodes in.Buf into in.Message using mi's fast-path coders and sets the
// UnmarshalInitialized output flag when the decoded message was found to be
// initialized. in.Message must be a *messageState or a *messageReflectWrapper;
// any other dynamic type panics at the unchecked type assertion below. The
// input flags and resolver are forwarded unchanged, and the returned error is
// whatever unmarshalPointer reports.
func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
	var p pointer
	if ms, ok := in.Message.(*messageState); ok {
		p = ms.pointer()
	} else {
		p = in.Message.(*messageReflectWrapper).pointer()
	}
	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
		flags:    in.Flags,
		resolver: in.Resolver,
	})
	var flags piface.UnmarshalOutputFlags
	if out.initialized {
		flags |= piface.UnmarshalInitialized
	}
	return piface.UnmarshalOutput{
		Flags: flags,
	}, err
}

// errUnknown is returned during unmarshaling to indicate a parse error that
// should result in a field being placed in the unknown fields section (for example,
// when the wire type doesn't match) as opposed to the entire unmarshal operation
// failing (for example, when a field extends past the available input).
//
// This is a sentinel error which should never be visible to the user.
var errUnknown = errors.New("unknown")

// unmarshalPointer decodes the wire-format bytes b into the message at p.
//
// If groupTag is nonzero, b holds the body of a group with that field number.
// Decoding stops at the matching end-group tag, which is consumed. A mismatched
// end-group tag, or input that ends before the end-group tag, is an error.
//
// Fields with a registered coder are decoded through it, and other field
// numbers are tried as extensions. Any field that yields errUnknown is skipped
// and, unless DiscardUnknown is set or the message has no unknown-fields
// storage, its tag and raw value are appended to the unknown fields.
//
// out.n is the number of bytes consumed. out.initialized is true only if no
// sub-field reported itself uninitialized and, when the message has required
// fields, the number of distinct required bits seen equals mi.numRequiredFields.
// When flags.ProtoLegacy is set and the message is a MessageSet, decoding is
// delegated entirely to unmarshalMessageSet.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	if flags.ProtoLegacy && mi.isMessageSet {
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	var requiredMask uint64
	// exts is looked up lazily, on the first field that has no coder.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// Inline fast paths handle 1- and 2-byte varint tags; longer tags
		// go through protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, protowire.ParseError(n)
			}
			b = b[n:]
		}
		// Reject field numbers outside [MinValidNumber, MaxValidNumber].
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errors.New("invalid field number")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errors.New("mismatching end group marker")
			}
			// Clearing groupTag records that the group was properly
			// terminated, for the check after the loop.
			groupTag = 0
			break
		}

		// Small field numbers index the dense slice; others use the map.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// Default to errUnknown (this shadows the named result err). Any path
		// below that does not decode the field leaves it to be handled as an
		// unknown field. That includes no unmarshal func, no extension storage,
		// and an unresolved extension.
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Possible extension.
			// The extension map is allocated here on first use, even if
			// this field later turns out not to be a known extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Skip the raw field value. If unknown fields are retained,
			// re-append its tag followed by the raw value bytes.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, protowire.ParseError(n)
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := p.Apply(mi.unknownOffset).Bytes()
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	if groupTag != 0 {
		return out, errors.New("missing end group marker")
	}
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	// The end-group tag, if any, was consumed before the loop exited, so it
	// is included in the count.
	out.n = start - len(b)
	return out, nil
}

// unmarshalExtension decodes one occurrence of extension field num from b and
// stores the result in exts.
//
// The extension type comes from an existing entry in exts or, failing that, is
// resolved through opts.resolver. If the resolver reports preg.NotFound, or the
// type has no unmarshal function, the result is errUnknown, so the caller keeps
// the field as unknown. Any other resolver error is returned with context.
//
// When flags.LazyUnmarshalExtensions is enabled and opts.IsDefault(), the value
// is first validated with skipExtension. If it is valid and initialized, the
// raw bytes are stored via appendLazyBytes instead of being decoded eagerly.
func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
	x := exts[int32(num)]
	xt := x.Type()
	if xt == nil {
		var err error
		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
		if err != nil {
			if err == preg.NotFound {
				return out, errUnknown
			}
			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
		}
	}
	xi := getExtensionFieldInfo(xt)
	if xi.funcs.unmarshal == nil {
		return out, errUnknown
	}
	if flags.LazyUnmarshalExtensions {
		if opts.IsDefault() && x.canLazy(xt) {
			// This := declares an out scoped to this if-block. The eager path
			// below starts again from the zero-valued named result.
			out, valid := skipExtension(b, xi, num, wtyp, opts)
			switch valid {
			case ValidationValid:
				if out.initialized {
					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
					exts[int32(num)] = x
					return out, nil
				}
				// Valid but not initialized: fall through to eager decoding.
			case ValidationInvalid:
				return out, errors.New("invalid wire format")
			case ValidationUnknown:
				// Could not be validated up front: decode eagerly below.
			}
		}
	}
	ival := x.Value()
	if !ival.IsValid() && xi.unmarshalNeedsValue {
		// Create a new message, list, or map value to fill in.
		// For enums, create a prototype value to let the unmarshal func know the
		// concrete type.
		ival = xt.New()
	}
	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
	if err != nil {
		return out, err
	}
	// Extension kinds without an isInit check are always treated as initialized.
	if xi.funcs.isInit == nil {
		out.initialized = true
	}
	x.Set(xt, v)
	exts[int32(num)] = x
	return out, nil
}

// skipExtension validates, without decoding, a single extension value at the
// start of b, using the validation info recorded in xi.
//
// It returns ValidationUnknown in these cases:
//   - xi has no validation MessageInfo.
//   - The wire type does not match the expected encoding (bytes for a message,
//     start-group for a group).
//   - The length prefix is malformed.
//   - The validation kind is neither message nor group.
//
// For a length-delimited message, out.n is set to the n returned by
// protowire.ConsumeBytes. For a group, out is whatever validate reports.
func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
	if xi.validation.mi == nil {
		return out, ValidationUnknown
	}
	xi.validation.mi.init()
	switch xi.validation.typ {
	case validationTypeMessage:
		if wtyp != protowire.BytesType {
			return out, ValidationUnknown
		}
		v, n := protowire.ConsumeBytes(b)
		if n < 0 {
			// Left to the caller's eager decode path rather than reported
			// as invalid here.
			return out, ValidationUnknown
		}
		out, st := xi.validation.mi.validate(v, 0, opts)
		out.n = n
		return out, st
	case validationTypeGroup:
		if wtyp != protowire.StartGroupType {
			return out, ValidationUnknown
		}
		out, st := xi.validation.mi.validate(b, num, opts)
		return out, st
	default:
		return out, ValidationUnknown
	}
}