1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"errors"
9	"fmt"
10	"reflect"
11
12	"google.golang.org/protobuf/encoding/protowire"
13	"google.golang.org/protobuf/proto"
14	"google.golang.org/protobuf/reflect/protoreflect"
15	"google.golang.org/protobuf/reflect/protoregistry"
16	"google.golang.org/protobuf/runtime/protoiface"
17	"google.golang.org/protobuf/runtime/protoimpl"
18)
19
20type (
21	// ExtensionDesc represents an extension descriptor and
22	// is used to interact with an extension field in a message.
23	//
24	// Variables of this type are generated in code by protoc-gen-go.
25	ExtensionDesc = protoimpl.ExtensionInfo
26
27	// ExtensionRange represents a range of message extensions.
28	// Used in code generated by protoc-gen-go.
29	ExtensionRange = protoiface.ExtensionRangeV1
30
31	// Deprecated: Do not use; this is an internal type.
32	Extension = protoimpl.ExtensionFieldV1
33
34	// Deprecated: Do not use; this is an internal type.
35	XXX_InternalExtensions = protoimpl.ExtensionFields
36)
37
var (
	// ErrMissingExtension is the error GetExtension reports when the
	// requested extension is neither populated in the message nor declared
	// with a default value.
	ErrMissingExtension = errors.New("proto: missing extension")

	// errNotExtendable is reported when the target message is nil or
	// invalid, or (for most functions in this file) declares no extension
	// ranges.
	errNotExtendable = errors.New("proto: not an extendable proto.Message")
)
42
43// HasExtension reports whether the extension field is present in m
44// either as an explicitly populated field or as an unknown field.
45func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
46	mr := MessageReflect(m)
47	if mr == nil || !mr.IsValid() {
48		return false
49	}
50
51	// Check whether any populated known field matches the field number.
52	xtd := xt.TypeDescriptor()
53	if isValidExtension(mr.Descriptor(), xtd) {
54		has = mr.Has(xtd)
55	} else {
56		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
57			has = int32(fd.Number()) == xt.Field
58			return !has
59		})
60	}
61
62	// Check whether any unknown field matches the field number.
63	for b := mr.GetUnknown(); !has && len(b) > 0; {
64		num, _, n := protowire.ConsumeField(b)
65		has = int32(num) == xt.Field
66		b = b[n:]
67	}
68	return has
69}
70
71// ClearExtension removes the extension field from m
72// either as an explicitly populated field or as an unknown field.
73func ClearExtension(m Message, xt *ExtensionDesc) {
74	mr := MessageReflect(m)
75	if mr == nil || !mr.IsValid() {
76		return
77	}
78
79	xtd := xt.TypeDescriptor()
80	if isValidExtension(mr.Descriptor(), xtd) {
81		mr.Clear(xtd)
82	} else {
83		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
84			if int32(fd.Number()) == xt.Field {
85				mr.Clear(fd)
86				return false
87			}
88			return true
89		})
90	}
91	clearUnknown(mr, fieldNum(xt.Field))
92}
93
94// ClearAllExtensions clears all extensions from m.
95// This includes populated fields and unknown fields in the extension range.
96func ClearAllExtensions(m Message) {
97	mr := MessageReflect(m)
98	if mr == nil || !mr.IsValid() {
99		return
100	}
101
102	mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
103		if fd.IsExtension() {
104			mr.Clear(fd)
105		}
106		return true
107	})
108	clearUnknown(mr, mr.Descriptor().ExtensionRanges())
109}
110
111// GetExtension retrieves a proto2 extended field from m.
112//
113// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
114// then GetExtension parses the encoded field and returns a Go value of the specified type.
115// If the field is not present, then the default value is returned (if one is specified),
116// otherwise ErrMissingExtension is reported.
117//
118// If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
119// then GetExtension returns the raw encoded bytes for the extension field.
120func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
121	mr := MessageReflect(m)
122	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
123		return nil, errNotExtendable
124	}
125
126	// Retrieve the unknown fields for this extension field.
127	var bo protoreflect.RawFields
128	for bi := mr.GetUnknown(); len(bi) > 0; {
129		num, _, n := protowire.ConsumeField(bi)
130		if int32(num) == xt.Field {
131			bo = append(bo, bi[:n]...)
132		}
133		bi = bi[n:]
134	}
135
136	// For type incomplete descriptors, only retrieve the unknown fields.
137	if xt.ExtensionType == nil {
138		return []byte(bo), nil
139	}
140
141	// If the extension field only exists as unknown fields, unmarshal it.
142	// This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
143	xtd := xt.TypeDescriptor()
144	if !isValidExtension(mr.Descriptor(), xtd) {
145		return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
146	}
147	if !mr.Has(xtd) && len(bo) > 0 {
148		m2 := mr.New()
149		if err := (proto.UnmarshalOptions{
150			Resolver: extensionResolver{xt},
151		}.Unmarshal(bo, m2.Interface())); err != nil {
152			return nil, err
153		}
154		if m2.Has(xtd) {
155			mr.Set(xtd, m2.Get(xtd))
156			clearUnknown(mr, fieldNum(xt.Field))
157		}
158	}
159
160	// Check whether the message has the extension field set or a default.
161	var pv protoreflect.Value
162	switch {
163	case mr.Has(xtd):
164		pv = mr.Get(xtd)
165	case xtd.HasDefault():
166		pv = xtd.Default()
167	default:
168		return nil, ErrMissingExtension
169	}
170
171	v := xt.InterfaceOf(pv)
172	rv := reflect.ValueOf(v)
173	if isScalarKind(rv.Kind()) {
174		rv2 := reflect.New(rv.Type())
175		rv2.Elem().Set(rv)
176		v = rv2.Interface()
177	}
178	return v, nil
179}
180
181// extensionResolver is a custom extension resolver that stores a single
182// extension type that takes precedence over the global registry.
183type extensionResolver struct{ xt protoreflect.ExtensionType }
184
185func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
186	if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
187		return r.xt, nil
188	}
189	return protoregistry.GlobalTypes.FindExtensionByName(field)
190}
191
192func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
193	if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
194		return r.xt, nil
195	}
196	return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
197}
198
199// GetExtensions returns a list of the extensions values present in m,
200// corresponding with the provided list of extension descriptors, xts.
201// If an extension is missing in m, the corresponding value is nil.
202func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
203	mr := MessageReflect(m)
204	if mr == nil || !mr.IsValid() {
205		return nil, errNotExtendable
206	}
207
208	vs := make([]interface{}, len(xts))
209	for i, xt := range xts {
210		v, err := GetExtension(m, xt)
211		if err != nil {
212			if err == ErrMissingExtension {
213				continue
214			}
215			return vs, err
216		}
217		vs[i] = v
218	}
219	return vs, nil
220}
221
222// SetExtension sets an extension field in m to the provided value.
223func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
224	mr := MessageReflect(m)
225	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
226		return errNotExtendable
227	}
228
229	rv := reflect.ValueOf(v)
230	if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
231		return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
232	}
233	if rv.Kind() == reflect.Ptr {
234		if rv.IsNil() {
235			return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
236		}
237		if isScalarKind(rv.Elem().Kind()) {
238			v = rv.Elem().Interface()
239		}
240	}
241
242	xtd := xt.TypeDescriptor()
243	if !isValidExtension(mr.Descriptor(), xtd) {
244		return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
245	}
246	mr.Set(xtd, xt.ValueOf(v))
247	clearUnknown(mr, fieldNum(xt.Field))
248	return nil
249}
250
251// SetRawExtension inserts b into the unknown fields of m.
252//
253// Deprecated: Use Message.ProtoReflect.SetUnknown instead.
254func SetRawExtension(m Message, fnum int32, b []byte) {
255	mr := MessageReflect(m)
256	if mr == nil || !mr.IsValid() {
257		return
258	}
259
260	// Verify that the raw field is valid.
261	for b0 := b; len(b0) > 0; {
262		num, _, n := protowire.ConsumeField(b0)
263		if int32(num) != fnum {
264			panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
265		}
266		b0 = b0[n:]
267	}
268
269	ClearExtension(m, &ExtensionDesc{Field: fnum})
270	mr.SetUnknown(append(mr.GetUnknown(), b...))
271}
272
273// ExtensionDescs returns a list of extension descriptors found in m,
274// containing descriptors for both populated extension fields in m and
275// also unknown fields of m that are in the extension range.
276// For the later case, an type incomplete descriptor is provided where only
277// the ExtensionDesc.Field field is populated.
278// The order of the extension descriptors is undefined.
279func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
280	mr := MessageReflect(m)
281	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
282		return nil, errNotExtendable
283	}
284
285	// Collect a set of known extension descriptors.
286	extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
287	mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
288		if fd.IsExtension() {
289			xt := fd.(protoreflect.ExtensionTypeDescriptor)
290			if xd, ok := xt.Type().(*ExtensionDesc); ok {
291				extDescs[fd.Number()] = xd
292			}
293		}
294		return true
295	})
296
297	// Collect a set of unknown extension descriptors.
298	extRanges := mr.Descriptor().ExtensionRanges()
299	for b := mr.GetUnknown(); len(b) > 0; {
300		num, _, n := protowire.ConsumeField(b)
301		if extRanges.Has(num) && extDescs[num] == nil {
302			extDescs[num] = nil
303		}
304		b = b[n:]
305	}
306
307	// Transpose the set of descriptors into a list.
308	var xts []*ExtensionDesc
309	for num, xt := range extDescs {
310		if xt == nil {
311			xt = &ExtensionDesc{Field: int32(num)}
312		}
313		xts = append(xts, xt)
314	}
315	return xts, nil
316}
317
318// isValidExtension reports whether xtd is a valid extension descriptor for md.
319func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
320	return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
321}
322
// isScalarKind reports whether k is one of the Go kinds used for protobuf
// scalar values, excluding bytes (which are slices).
// It exists for historical reasons: v1 represents scalars as *T while v2
// uses T, so callers convert between the two forms for these kinds.
func isScalarKind(k reflect.Kind) bool {
	switch k {
	case reflect.Bool, reflect.String:
		return true
	case reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64:
		return true
	case reflect.Float32, reflect.Float64:
		return true
	}
	return false
}
334
335// clearUnknown removes unknown fields from m where remover.Has reports true.
336func clearUnknown(m protoreflect.Message, remover interface {
337	Has(protoreflect.FieldNumber) bool
338}) {
339	var bo protoreflect.RawFields
340	for bi := m.GetUnknown(); len(bi) > 0; {
341		num, _, n := protowire.ConsumeField(bi)
342		if !remover.Has(num) {
343			bo = append(bo, bi[:n]...)
344		}
345		bi = bi[n:]
346	}
347	if bi := m.GetUnknown(); len(bi) != len(bo) {
348		m.SetUnknown(bo)
349	}
350}
351
352type fieldNum protoreflect.FieldNumber
353
354func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
355	return protoreflect.FieldNumber(n1) == n2
356}
357