1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"math/bits"
9
10	"google.golang.org/protobuf/encoding/protowire"
11	"google.golang.org/protobuf/internal/errors"
12	"google.golang.org/protobuf/internal/flags"
13	"google.golang.org/protobuf/proto"
14	"google.golang.org/protobuf/reflect/protoreflect"
15	preg "google.golang.org/protobuf/reflect/protoregistry"
16	"google.golang.org/protobuf/runtime/protoiface"
17	piface "google.golang.org/protobuf/runtime/protoiface"
18)
19
// unmarshalOptions carries the per-call settings consumed by the fast-path
// unmarshal functions in this package (unmarshalPointer, unmarshalExtension,
// skipExtension).
//
// NOTE(review): in this file, this field declaration is the only use of the
// unaliased protoiface import; everything else refers to the same package
// through the piface alias. Consider unifying on one name.
type unmarshalOptions struct {
	// flags holds the caller's protoiface.UnmarshalInputFlags; DiscardUnknown
	// tests it for UnmarshalDiscardUnknown, and IsDefault requires it be zero.
	flags    protoiface.UnmarshalInputFlags
	// resolver is used to look up extension types; unmarshalExtension calls
	// FindExtensionByNumber with the containing message's full name.
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
}
27
28func (o unmarshalOptions) Options() proto.UnmarshalOptions {
29	return proto.UnmarshalOptions{
30		Merge:          true,
31		AllowPartial:   true,
32		DiscardUnknown: o.DiscardUnknown(),
33		Resolver:       o.resolver,
34	}
35}
36
37func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
38
39func (o unmarshalOptions) IsDefault() bool {
40	return o.flags == 0 && o.resolver == preg.GlobalTypes
41}
42
43var lazyUnmarshalOptions = unmarshalOptions{
44	resolver: preg.GlobalTypes,
45}
46
// unmarshalOutput is the internal result of the fast-path unmarshal functions.
type unmarshalOutput struct {
	n           int // number of bytes consumed
	// initialized is set when the decoder determined the decoded data to be
	// fully initialized (required fields present); unmarshal reports it to
	// callers as piface.UnmarshalInitialized.
	initialized bool
}
51
52// unmarshal is protoreflect.Methods.Unmarshal.
53func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
54	var p pointer
55	if ms, ok := in.Message.(*messageState); ok {
56		p = ms.pointer()
57	} else {
58		p = in.Message.(*messageReflectWrapper).pointer()
59	}
60	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
61		flags:    in.Flags,
62		resolver: in.Resolver,
63	})
64	var flags piface.UnmarshalOutputFlags
65	if out.initialized {
66		flags |= piface.UnmarshalInitialized
67	}
68	return piface.UnmarshalOutput{
69		Flags: flags,
70	}, err
71}
72
// errUnknown is returned during unmarshaling to indicate a parse error that
// should result in a field being placed in the unknown fields section (for
// example, when the wire type doesn't match) as opposed to the entire
// unmarshal operation failing (for example, when a field extends past the
// available input).
//
// This is a sentinel error which should never be visible to the user:
// unmarshalPointer compares against it by identity and, on a match, skips the
// field's bytes with protowire.ConsumeFieldValue instead of returning it.
var errUnknown = errors.New("unknown")
80
// unmarshalPointer decodes the wire-format bytes in b into the message at p
// (whose layout is described by mi) and reports how many bytes were consumed
// and whether the result is known to be initialized.
//
// When flags.ProtoLegacy is set and mi is a MessageSet, decoding is delegated
// to unmarshalMessageSet.
//
// groupTag is 0 when decoding a top-level or length-delimited message. When it
// is non-zero, b holds the body of a group with that field number: decoding
// stops after the matching end-group tag (which is included in out.n), and
// reaching the end of b before that tag is an error.
//
// A field is kept in the unknown-fields section (unless opts.DiscardUnknown()
// is set or mi has no unknown-fields storage) when its decoder returns
// errUnknown, when it has no unmarshal function, or when it is an extension
// that cannot be handled (no extension storage, or unmarshalExtension returns
// errUnknown). Any other error aborts decoding.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	if flags.ProtoLegacy && mi.isMessageSet {
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	// One bit per required field seen; compared against numRequiredFields after the loop.
	var requiredMask uint64
	// Fetched on first use, when a field number with no coder is encountered.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// Inline fast paths for one- and two-byte varint tags; anything longer
		// goes through protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, protowire.ParseError(n)
			}
			b = b[n:]
		}
		// Reject field numbers outside protowire's valid range.
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errors.New("invalid field number")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			// The end-group tag must close the group being decoded. At top
			// level groupTag is 0, which a validated field number is not
			// expected to equal, so any end-group tag there is a mismatch.
			if num != groupTag {
				return out, errors.New("mismatching end group marker")
			}
			// Mark the group closed so the post-loop check passes; this
			// break exits the for loop, with the end tag already consumed.
			groupTag = 0
			break
		}

		// Look up the field's coder: small numbers index denseCoderFields
		// directly, larger ones fall back to coderFields.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// Start from "unknown": every path below that does not successfully
		// decode leaves err as errUnknown so the bytes are skipped/preserved.
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			// Record that this field (if required) was seen.
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Possible extension.
			// Fetch the extension map on first use, allocating it if the
			// message does not have one yet.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Skip the field's value; n replaces any partial count set above.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, protowire.ParseError(n)
			}
			// Re-encode the tag (already stripped from b) followed by the raw value.
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := p.Apply(mi.unknownOffset).Bytes()
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	if groupTag != 0 {
		return out, errors.New("missing end group marker")
	}
	// If fewer distinct required fields were seen than the message declares,
	// it cannot be reported as initialized.
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
195
196func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
197	x := exts[int32(num)]
198	xt := x.Type()
199	if xt == nil {
200		var err error
201		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
202		if err != nil {
203			if err == preg.NotFound {
204				return out, errUnknown
205			}
206			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
207		}
208	}
209	xi := getExtensionFieldInfo(xt)
210	if xi.funcs.unmarshal == nil {
211		return out, errUnknown
212	}
213	if flags.LazyUnmarshalExtensions {
214		if opts.IsDefault() && x.canLazy(xt) {
215			out, valid := skipExtension(b, xi, num, wtyp, opts)
216			switch valid {
217			case ValidationValid:
218				if out.initialized {
219					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
220					exts[int32(num)] = x
221					return out, nil
222				}
223			case ValidationInvalid:
224				return out, errors.New("invalid wire format")
225			case ValidationUnknown:
226			}
227		}
228	}
229	ival := x.Value()
230	if !ival.IsValid() && xi.unmarshalNeedsValue {
231		// Create a new message, list, or map value to fill in.
232		// For enums, create a prototype value to let the unmarshal func know the
233		// concrete type.
234		ival = xt.New()
235	}
236	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
237	if err != nil {
238		return out, err
239	}
240	if xi.funcs.isInit == nil {
241		out.initialized = true
242	}
243	x.Set(xt, v)
244	exts[int32(num)] = x
245	return out, nil
246}
247
248func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
249	if xi.validation.mi == nil {
250		return out, ValidationUnknown
251	}
252	xi.validation.mi.init()
253	switch xi.validation.typ {
254	case validationTypeMessage:
255		if wtyp != protowire.BytesType {
256			return out, ValidationUnknown
257		}
258		v, n := protowire.ConsumeBytes(b)
259		if n < 0 {
260			return out, ValidationUnknown
261		}
262		out, st := xi.validation.mi.validate(v, 0, opts)
263		out.n = n
264		return out, st
265	case validationTypeGroup:
266		if wtyp != protowire.StartGroupType {
267			return out, ValidationUnknown
268		}
269		out, st := xi.validation.mi.validate(b, num, opts)
270		return out, st
271	default:
272		return out, ValidationUnknown
273	}
274}
275