1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"fmt"
9	"reflect"
10
11	"google.golang.org/protobuf/internal/pragma"
12	pref "google.golang.org/protobuf/reflect/protoreflect"
13)
14
// reflectMessageInfo holds the reflection accessors that a MessageInfo builds
// for its message type (populated by makeReflectFuncs and its helpers).
type reflectMessageInfo struct {
	// fields maps every declared field number to that field's accessor.
	fields map[pref.FieldNumber]*fieldInfo
	// oneofs maps every oneof name (synthetic oneofs included) to its accessor.
	oneofs map[pref.Name]*oneofInfo

	// denseFields is a subset of fields where:
	//	0 < fieldDesc.Number() < len(denseFields)
	// It provides faster access to the fieldInfo, but may be incomplete:
	// for any n in that range denseFields[n] equals fields[n] (nil if the
	// message has no such field), while larger field numbers are only
	// reachable through fields.
	denseFields []*fieldInfo

	// rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
	// Fields of synthetic oneofs appear as ordinary *fieldInfo entries; each
	// non-synthetic oneof appears exactly once as a *oneofInfo entry.
	rangeInfos []interface{} // either *fieldInfo or *oneofInfo

	// getUnknown and setUnknown read and write the unknown-field bytes of the
	// message located at the given pointer.
	getUnknown func(pointer) pref.RawFields
	setUnknown func(pointer, pref.RawFields)
	// extensionMap returns the message's extension storage, or nil when the
	// message pointer is nil or the struct has no extension field.
	extensionMap func(pointer) *extensionMap

	// nilMessage is used by MessageOf to produce the reflective view of a
	// nil message pointer.
	nilMessage atomicNilMessage
}
33
// makeReflectFuncs generates the set of functions to support reflection.
//
// It builds, in order, the accessors for declared fields and oneofs, the
// unknown-fields accessors, and the extension-map accessor. si describes the
// Go struct layout; t is forwarded to the latter two builders.
func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
	mi.makeKnownFieldsFunc(si)
	mi.makeUnknownFieldsFunc(t, si)
	mi.makeExtensionFieldsFunc(t, si)
}
40
41// makeKnownFieldsFunc generates functions for operations that can be performed
42// on each protobuf message field. It takes in a reflect.Type representing the
43// Go struct and matches message fields with struct fields.
44//
45// This code assumes that the struct is well-formed and panics if there are
46// any discrepancies.
47func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
48	mi.fields = map[pref.FieldNumber]*fieldInfo{}
49	md := mi.Desc
50	fds := md.Fields()
51	for i := 0; i < fds.Len(); i++ {
52		fd := fds.Get(i)
53		fs := si.fieldsByNumber[fd.Number()]
54		var fi fieldInfo
55		switch {
56		case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
57			fi = fieldInfoForOneof(fd, si.oneofsByName[fd.ContainingOneof().Name()], mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
58		case fd.IsMap():
59			fi = fieldInfoForMap(fd, fs, mi.Exporter)
60		case fd.IsList():
61			fi = fieldInfoForList(fd, fs, mi.Exporter)
62		case fd.IsWeak():
63			fi = fieldInfoForWeakMessage(fd, si.weakOffset)
64		case fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind:
65			fi = fieldInfoForMessage(fd, fs, mi.Exporter)
66		default:
67			fi = fieldInfoForScalar(fd, fs, mi.Exporter)
68		}
69		mi.fields[fd.Number()] = &fi
70	}
71
72	mi.oneofs = map[pref.Name]*oneofInfo{}
73	for i := 0; i < md.Oneofs().Len(); i++ {
74		od := md.Oneofs().Get(i)
75		mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
76	}
77
78	mi.denseFields = make([]*fieldInfo, fds.Len()*2)
79	for i := 0; i < fds.Len(); i++ {
80		if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
81			mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
82		}
83	}
84
85	for i := 0; i < fds.Len(); {
86		fd := fds.Get(i)
87		if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
88			mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
89			i += od.Fields().Len()
90		} else {
91			mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
92			i++
93		}
94	}
95}
96
97func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
98	mi.getUnknown = func(pointer) pref.RawFields { return nil }
99	mi.setUnknown = func(pointer, pref.RawFields) { return }
100	if si.unknownOffset.IsValid() {
101		mi.getUnknown = func(p pointer) pref.RawFields {
102			if p.IsNil() {
103				return nil
104			}
105			rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType)
106			return pref.RawFields(*rv.Interface().(*[]byte))
107		}
108		mi.setUnknown = func(p pointer, b pref.RawFields) {
109			if p.IsNil() {
110				panic("invalid SetUnknown on nil Message")
111			}
112			rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType)
113			*rv.Interface().(*[]byte) = []byte(b)
114		}
115	} else {
116		mi.getUnknown = func(pointer) pref.RawFields {
117			return nil
118		}
119		mi.setUnknown = func(p pointer, _ pref.RawFields) {
120			if p.IsNil() {
121				panic("invalid SetUnknown on nil Message")
122			}
123		}
124	}
125}
126
127func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
128	if si.extensionOffset.IsValid() {
129		mi.extensionMap = func(p pointer) *extensionMap {
130			if p.IsNil() {
131				return (*extensionMap)(nil)
132			}
133			v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
134			return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
135		}
136	} else {
137		mi.extensionMap = func(pointer) *extensionMap {
138			return (*extensionMap)(nil)
139		}
140	}
141}
142
// extensionMap holds a message's extension fields keyed by field number.
// Its methods use a *extensionMap receiver; the read-only ones (Range, Has,
// Get) accept a nil receiver, while Clear, Set, and Mutable dereference it.
type extensionMap map[int32]ExtensionField
144
145func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
146	if m != nil {
147		for _, x := range *m {
148			xd := x.Type().TypeDescriptor()
149			v := x.Value()
150			if xd.IsList() && v.List().Len() == 0 {
151				continue
152			}
153			if !f(xd, v) {
154				return
155			}
156		}
157	}
158}
159func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
160	if m == nil {
161		return false
162	}
163	xd := xt.TypeDescriptor()
164	x, ok := (*m)[int32(xd.Number())]
165	if !ok {
166		return false
167	}
168	switch {
169	case xd.IsList():
170		return x.Value().List().Len() > 0
171	case xd.IsMap():
172		return x.Value().Map().Len() > 0
173	case xd.Message() != nil:
174		return x.Value().Message().IsValid()
175	}
176	return true
177}
178func (m *extensionMap) Clear(xt pref.ExtensionType) {
179	delete(*m, int32(xt.TypeDescriptor().Number()))
180}
181func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
182	xd := xt.TypeDescriptor()
183	if m != nil {
184		if x, ok := (*m)[int32(xd.Number())]; ok {
185			return x.Value()
186		}
187	}
188	return xt.Zero()
189}
190func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
191	xd := xt.TypeDescriptor()
192	isValid := true
193	switch {
194	case !xt.IsValidValue(v):
195		isValid = false
196	case xd.IsList():
197		isValid = v.List().IsValid()
198	case xd.IsMap():
199		isValid = v.Map().IsValid()
200	case xd.Message() != nil:
201		isValid = v.Message().IsValid()
202	}
203	if !isValid {
204		panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
205	}
206
207	if *m == nil {
208		*m = make(map[int32]ExtensionField)
209	}
210	var x ExtensionField
211	x.Set(xt, v)
212	(*m)[int32(xd.Number())] = x
213}
214func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
215	xd := xt.TypeDescriptor()
216	if xd.Kind() != pref.MessageKind && xd.Kind() != pref.GroupKind && !xd.IsList() && !xd.IsMap() {
217		panic("invalid Mutable on field with non-composite type")
218	}
219	if x, ok := (*m)[int32(xd.Number())]; ok {
220		return x.Value()
221	}
222	v := xt.New()
223	m.Set(xt, v)
224	return v
225}
226
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info through its own fields, and a pointer to
// the MessageState is identical to a pointer to the concrete message value.
//
// Requirements:
//   - The type M must implement protoreflect.ProtoMessage.
//   - The address of m must not be nil.
//   - The address of m and the address of m.state must be equal,
//     even though they are different Go types.
type MessageState struct {
	// Marker fields from the internal pragma package; judging by their names
	// they restrict unkeyed literals, comparison, and copying (NOTE(review):
	// confirm against the pragma package, which is not shown here).
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// atomicMessageInfo is the message info for this instance; as stated
	// above, it must be set atomically.
	atomicMessageInfo *MessageInfo
}
276
// messageState is a distinct named type sharing MessageState's underlying
// struct, so a *MessageState can be converted to a *messageState.
type messageState MessageState

// Compile-time checks that *messageState satisfies the reflective message
// interface and the package's unwrapper interface.
var (
	_ pref.Message = (*messageState)(nil)
	_ unwrapper    = (*messageState)(nil)
)
283
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B.
type messageDataType struct {
	p  pointer      // address of the message value (see MessageOf)
	mi *MessageInfo // type information used to interpret p
}
292
// messageReflectWrapper and messageIfaceWrapper share messageDataType's
// layout, so a pointer to one converts directly to a pointer to the other.
// messageReflectWrapper is the protoreflect.Message view of a message;
// messageIfaceWrapper is the protoreflect.ProtoMessage view of it.
type (
	messageReflectWrapper messageDataType
	messageIfaceWrapper   messageDataType
)

// Compile-time checks that both wrappers satisfy their intended interfaces.
var (
	_ pref.Message      = (*messageReflectWrapper)(nil)
	_ unwrapper         = (*messageReflectWrapper)(nil)
	_ pref.ProtoMessage = (*messageIfaceWrapper)(nil)
	_ unwrapper         = (*messageIfaceWrapper)(nil)
)
304
305// MessageOf returns a reflective view over a message. The input must be a
306// pointer to a named Go struct. If the provided type has a ProtoReflect method,
307// it must be implemented by calling this method.
308func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
309	// TODO: Switch the input to be an opaque Pointer.
310	if reflect.TypeOf(m) != mi.GoReflectType {
311		panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
312	}
313	p := pointerOfIface(m)
314	if p.IsNil() {
315		return mi.nilMessage.Init(mi)
316	}
317	return &messageReflectWrapper{p, mi}
318}
319
320func (m *messageReflectWrapper) pointer() pointer          { return m.p }
321func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
322
323func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
324	return (*messageReflectWrapper)(m)
325}
326func (m *messageIfaceWrapper) protoUnwrap() interface{} {
327	return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
328}
329
330// checkField verifies that the provided field descriptor is valid.
331// Exactly one of the returned values is populated.
332func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
333	var fi *fieldInfo
334	if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
335		fi = mi.denseFields[n]
336	} else {
337		fi = mi.fields[n]
338	}
339	if fi != nil {
340		if fi.fieldDesc != fd {
341			if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
342				panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
343			}
344			panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
345		}
346		return fi, nil
347	}
348
349	if fd.IsExtension() {
350		if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
351			// TODO: Should this be exact containing message descriptor match?
352			panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
353		}
354		if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
355			panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
356		}
357		xtd, ok := fd.(pref.ExtensionTypeDescriptor)
358		if !ok {
359			panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
360		}
361		return nil, xtd.Type()
362	}
363	panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
364}
365