1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"fmt"
9	"reflect"
10
11	"google.golang.org/protobuf/internal/detrand"
12	"google.golang.org/protobuf/internal/pragma"
13	pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
// reflectMessageInfo holds the per-message-type state used to implement
// protobuf reflection over a Go struct. Its fields are populated by the
// make*Func / makeFieldTypes methods on *MessageInfo (which presumably embeds
// this struct — they assign mi.fields, mi.getUnknown, and so on).
type reflectMessageInfo struct {
	// fields maps each known field number of the message to its accessors.
	fields map[pref.FieldNumber]*fieldInfo
	// oneofs maps each oneof name declared by the message to its accessors.
	oneofs map[pref.Name]*oneofInfo

	// fieldTypes contains the zero value of an enum or message field.
	// For lists, it contains the element type.
	// For maps, it contains the entry value type.
	// It is allocated lazily and stays nil if no field qualifies.
	fieldTypes map[pref.FieldNumber]interface{}

	// denseFields is a subset of fields where:
	//	0 < fieldDesc.Number() < len(denseFields)
	// It provides faster access to the fieldInfo, but may be incomplete.
	denseFields []*fieldInfo

	// rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
	rangeInfos []interface{} // either *fieldInfo or *oneofInfo

	// Accessors for a message instance's unknown-field bytes and extension
	// storage; each takes the address of the message.
	getUnknown   func(pointer) pref.RawFields
	setUnknown   func(pointer, pref.RawFields)
	extensionMap func(pointer) *extensionMap

	// nilMessage produces the reflective view that MessageOf returns for a
	// nil message pointer.
	nilMessage atomicNilMessage
}
39
// makeReflectFuncs generates the set of functions to support reflection.
//
// It builds, in order: the known-field and oneof accessors, the unknown-field
// getter/setter, the extension-map accessor, and the per-field zero values
// stored in fieldTypes. The struct layout si drives all of them; t is only
// forwarded to the unknown-field and extension builders.
func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
	mi.makeKnownFieldsFunc(si)
	mi.makeUnknownFieldsFunc(t, si)
	mi.makeExtensionFieldsFunc(t, si)
	mi.makeFieldTypes(si)
}
47
48// makeKnownFieldsFunc generates functions for operations that can be performed
49// on each protobuf message field. It takes in a reflect.Type representing the
50// Go struct and matches message fields with struct fields.
51//
52// This code assumes that the struct is well-formed and panics if there are
53// any discrepancies.
54func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
55	mi.fields = map[pref.FieldNumber]*fieldInfo{}
56	md := mi.Desc
57	fds := md.Fields()
58	for i := 0; i < fds.Len(); i++ {
59		fd := fds.Get(i)
60		fs := si.fieldsByNumber[fd.Number()]
61		isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
62		if isOneof {
63			fs = si.oneofsByName[fd.ContainingOneof().Name()]
64		}
65		var fi fieldInfo
66		switch {
67		case fs.Type == nil:
68			fi = fieldInfoForMissing(fd) // never occurs for officially generated message types
69		case isOneof:
70			fi = fieldInfoForOneof(fd, fs, mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
71		case fd.IsMap():
72			fi = fieldInfoForMap(fd, fs, mi.Exporter)
73		case fd.IsList():
74			fi = fieldInfoForList(fd, fs, mi.Exporter)
75		case fd.IsWeak():
76			fi = fieldInfoForWeakMessage(fd, si.weakOffset)
77		case fd.Message() != nil:
78			fi = fieldInfoForMessage(fd, fs, mi.Exporter)
79		default:
80			fi = fieldInfoForScalar(fd, fs, mi.Exporter)
81		}
82		mi.fields[fd.Number()] = &fi
83	}
84
85	mi.oneofs = map[pref.Name]*oneofInfo{}
86	for i := 0; i < md.Oneofs().Len(); i++ {
87		od := md.Oneofs().Get(i)
88		mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
89	}
90
91	mi.denseFields = make([]*fieldInfo, fds.Len()*2)
92	for i := 0; i < fds.Len(); i++ {
93		if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
94			mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
95		}
96	}
97
98	for i := 0; i < fds.Len(); {
99		fd := fds.Get(i)
100		if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
101			mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
102			i += od.Fields().Len()
103		} else {
104			mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
105			i++
106		}
107	}
108
109	// Introduce instability to iteration order, but keep it deterministic.
110	if len(mi.rangeInfos) > 1 && detrand.Bool() {
111		i := detrand.Intn(len(mi.rangeInfos) - 1)
112		mi.rangeInfos[i], mi.rangeInfos[i+1] = mi.rangeInfos[i+1], mi.rangeInfos[i]
113	}
114}
115
116func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
117	switch {
118	case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
119		// Handle as []byte.
120		mi.getUnknown = func(p pointer) pref.RawFields {
121			if p.IsNil() {
122				return nil
123			}
124			return *p.Apply(mi.unknownOffset).Bytes()
125		}
126		mi.setUnknown = func(p pointer, b pref.RawFields) {
127			if p.IsNil() {
128				panic("invalid SetUnknown on nil Message")
129			}
130			*p.Apply(mi.unknownOffset).Bytes() = b
131		}
132	case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
133		// Handle as *[]byte.
134		mi.getUnknown = func(p pointer) pref.RawFields {
135			if p.IsNil() {
136				return nil
137			}
138			bp := p.Apply(mi.unknownOffset).BytesPtr()
139			if *bp == nil {
140				return nil
141			}
142			return **bp
143		}
144		mi.setUnknown = func(p pointer, b pref.RawFields) {
145			if p.IsNil() {
146				panic("invalid SetUnknown on nil Message")
147			}
148			bp := p.Apply(mi.unknownOffset).BytesPtr()
149			if *bp == nil {
150				*bp = new([]byte)
151			}
152			**bp = b
153		}
154	default:
155		mi.getUnknown = func(pointer) pref.RawFields {
156			return nil
157		}
158		mi.setUnknown = func(p pointer, _ pref.RawFields) {
159			if p.IsNil() {
160				panic("invalid SetUnknown on nil Message")
161			}
162		}
163	}
164}
165
166func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
167	if si.extensionOffset.IsValid() {
168		mi.extensionMap = func(p pointer) *extensionMap {
169			if p.IsNil() {
170				return (*extensionMap)(nil)
171			}
172			v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
173			return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
174		}
175	} else {
176		mi.extensionMap = func(pointer) *extensionMap {
177			return (*extensionMap)(nil)
178		}
179	}
180}
181func (mi *MessageInfo) makeFieldTypes(si structInfo) {
182	md := mi.Desc
183	fds := md.Fields()
184	for i := 0; i < fds.Len(); i++ {
185		var ft reflect.Type
186		fd := fds.Get(i)
187		fs := si.fieldsByNumber[fd.Number()]
188		isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
189		if isOneof {
190			fs = si.oneofsByName[fd.ContainingOneof().Name()]
191		}
192		var isMessage bool
193		switch {
194		case fs.Type == nil:
195			continue // never occurs for officially generated message types
196		case isOneof:
197			if fd.Enum() != nil || fd.Message() != nil {
198				ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
199			}
200		case fd.IsMap():
201			if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
202				ft = fs.Type.Elem()
203			}
204			isMessage = fd.MapValue().Message() != nil
205		case fd.IsList():
206			if fd.Enum() != nil || fd.Message() != nil {
207				ft = fs.Type.Elem()
208			}
209			isMessage = fd.Message() != nil
210		case fd.Enum() != nil:
211			ft = fs.Type
212			if fd.HasPresence() && ft.Kind() == reflect.Ptr {
213				ft = ft.Elem()
214			}
215		case fd.Message() != nil:
216			ft = fs.Type
217			if fd.IsWeak() {
218				ft = nil
219			}
220			isMessage = true
221		}
222		if isMessage && ft != nil && ft.Kind() != reflect.Ptr {
223			ft = reflect.PtrTo(ft) // never occurs for officially generated message types
224		}
225		if ft != nil {
226			if mi.fieldTypes == nil {
227				mi.fieldTypes = make(map[pref.FieldNumber]interface{})
228			}
229			mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
230		}
231	}
232}
233
// extensionMap stores the extension fields set on a message, keyed by
// extension field number.
type extensionMap map[int32]ExtensionField
235
236func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
237	if m != nil {
238		for _, x := range *m {
239			xd := x.Type().TypeDescriptor()
240			v := x.Value()
241			if xd.IsList() && v.List().Len() == 0 {
242				continue
243			}
244			if !f(xd, v) {
245				return
246			}
247		}
248	}
249}
250func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
251	if m == nil {
252		return false
253	}
254	xd := xt.TypeDescriptor()
255	x, ok := (*m)[int32(xd.Number())]
256	if !ok {
257		return false
258	}
259	switch {
260	case xd.IsList():
261		return x.Value().List().Len() > 0
262	case xd.IsMap():
263		return x.Value().Map().Len() > 0
264	case xd.Message() != nil:
265		return x.Value().Message().IsValid()
266	}
267	return true
268}
269func (m *extensionMap) Clear(xt pref.ExtensionType) {
270	delete(*m, int32(xt.TypeDescriptor().Number()))
271}
272func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
273	xd := xt.TypeDescriptor()
274	if m != nil {
275		if x, ok := (*m)[int32(xd.Number())]; ok {
276			return x.Value()
277		}
278	}
279	return xt.Zero()
280}
281func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
282	xd := xt.TypeDescriptor()
283	isValid := true
284	switch {
285	case !xt.IsValidValue(v):
286		isValid = false
287	case xd.IsList():
288		isValid = v.List().IsValid()
289	case xd.IsMap():
290		isValid = v.Map().IsValid()
291	case xd.Message() != nil:
292		isValid = v.Message().IsValid()
293	}
294	if !isValid {
295		panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
296	}
297
298	if *m == nil {
299		*m = make(map[int32]ExtensionField)
300	}
301	var x ExtensionField
302	x.Set(xt, v)
303	(*m)[int32(xd.Number())] = x
304}
305func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
306	xd := xt.TypeDescriptor()
307	if xd.Kind() != pref.MessageKind && xd.Kind() != pref.GroupKind && !xd.IsList() && !xd.IsMap() {
308		panic("invalid Mutable on field with non-composite type")
309	}
310	if x, ok := (*m)[int32(xd.Number())]; ok {
311		return x.Value()
312	}
313	v := xt.New()
314	m.Set(xt, v)
315	return v
316}
317
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It reaches the message info through its atomicMessageInfo field, and a
// pointer to the MessageState is identical to a pointer to the concrete
// message value.
//
// Requirements:
//	• The type M must implement protoreflect.ProtoMessage.
//	• The address of m must not be nil.
//	• The address of m and the address of m.state must be equal,
//	  even though they are different Go types.
type MessageState struct {
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// atomicMessageInfo is the *MessageInfo for this message instance; as
	// stated in the type comment, it must be set atomically.
	atomicMessageInfo *MessageInfo
}
367
// messageState is a distinct defined type with the same underlying layout as
// MessageState. Methods declared on it (those satisfying the assertions below)
// do not become part of MessageState's own method set.
type messageState MessageState

// Compile-time checks that *messageState satisfies pref.Message and unwrapper.
var (
	_ pref.Message = (*messageState)(nil)
	_ unwrapper    = (*messageState)(nil)
)
374
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B.
type messageDataType struct {
	// p is the address of the message data.
	p  pointer
	// mi describes the message's type.
	mi *MessageInfo
}

// Both wrapper types are defined on messageDataType. A pointer to one can
// therefore be converted directly to a pointer to the other; see
// messageIfaceWrapper.ProtoReflect.
//   - messageReflectWrapper is the pref.Message view returned by MessageOf.
//   - messageIfaceWrapper exposes the same data as a pref.ProtoMessage.
type (
	messageReflectWrapper messageDataType
	messageIfaceWrapper   messageDataType
)

// Compile-time interface checks for the wrapper types.
var (
	_ pref.Message      = (*messageReflectWrapper)(nil)
	_ unwrapper         = (*messageReflectWrapper)(nil)
	_ pref.ProtoMessage = (*messageIfaceWrapper)(nil)
	_ unwrapper         = (*messageIfaceWrapper)(nil)
)
395
396// MessageOf returns a reflective view over a message. The input must be a
397// pointer to a named Go struct. If the provided type has a ProtoReflect method,
398// it must be implemented by calling this method.
399func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
400	if reflect.TypeOf(m) != mi.GoReflectType {
401		panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
402	}
403	p := pointerOfIface(m)
404	if p.IsNil() {
405		return mi.nilMessage.Init(mi)
406	}
407	return &messageReflectWrapper{p, mi}
408}
409
410func (m *messageReflectWrapper) pointer() pointer          { return m.p }
411func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
412
413// Reset implements the v1 proto.Message.Reset method.
414func (m *messageIfaceWrapper) Reset() {
415	if mr, ok := m.protoUnwrap().(interface{ Reset() }); ok {
416		mr.Reset()
417		return
418	}
419	rv := reflect.ValueOf(m.protoUnwrap())
420	if rv.Kind() == reflect.Ptr && !rv.IsNil() {
421		rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
422	}
423}
424func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
425	return (*messageReflectWrapper)(m)
426}
427func (m *messageIfaceWrapper) protoUnwrap() interface{} {
428	return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
429}
430
431// checkField verifies that the provided field descriptor is valid.
432// Exactly one of the returned values is populated.
433func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
434	var fi *fieldInfo
435	if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
436		fi = mi.denseFields[n]
437	} else {
438		fi = mi.fields[n]
439	}
440	if fi != nil {
441		if fi.fieldDesc != fd {
442			if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
443				panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
444			}
445			panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
446		}
447		return fi, nil
448	}
449
450	if fd.IsExtension() {
451		if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
452			// TODO: Should this be exact containing message descriptor match?
453			panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
454		}
455		if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
456			panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
457		}
458		xtd, ok := fd.(pref.ExtensionTypeDescriptor)
459		if !ok {
460			panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
461		}
462		return nil, xtd.Type()
463	}
464	panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
465}
466