1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"fmt"
9	"math"
10	"reflect"
11	"sync"
12
13	"google.golang.org/protobuf/internal/flags"
14	pref "google.golang.org/protobuf/reflect/protoreflect"
15	preg "google.golang.org/protobuf/reflect/protoregistry"
16)
17
// fieldInfo holds the reflective accessors for a single message field.
// Accessors that take a pointer operate on the message struct that contains
// the field.
type fieldInfo struct {
	// fieldDesc describes the field that these accessors operate on.
	fieldDesc pref.FieldDescriptor

	// These fields are used for protobuf reflection support.
	//
	// has reports whether the field is populated.
	has        func(pointer) bool
	// clear resets the field to its unpopulated state.
	clear      func(pointer)
	// get returns the field's current value.
	get        func(pointer) pref.Value
	// set stores a value into the field.
	set        func(pointer, pref.Value)
	// mutable returns a mutable reference to a composite field, allocating
	// it first if necessary. It is left nil by constructors of scalar fields.
	mutable    func(pointer) pref.Value
	// newMessage returns a new message of the field's message type; only
	// meaningful for message-typed fields.
	newMessage func() pref.Message
	// newField returns a new, empty value suitable for this field.
	newField   func() pref.Value
}
30
31func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x exporter, ot reflect.Type) fieldInfo {
32	ft := fs.Type
33	if ft.Kind() != reflect.Interface {
34		panic(fmt.Sprintf("field %v has invalid type: got %v, want interface kind", fd.FullName(), ft))
35	}
36	if ot.Kind() != reflect.Struct {
37		panic(fmt.Sprintf("field %v has invalid type: got %v, want struct kind", fd.FullName(), ot))
38	}
39	if !reflect.PtrTo(ot).Implements(ft) {
40		panic(fmt.Sprintf("field %v has invalid type: %v does not implement %v", fd.FullName(), ot, ft))
41	}
42	conv := NewConverter(ot.Field(0).Type, fd)
43	isMessage := fd.Message() != nil
44
45	// TODO: Implement unsafe fast path?
46	fieldOffset := offsetOf(fs, x)
47	return fieldInfo{
48		// NOTE: The logic below intentionally assumes that oneof fields are
49		// well-formatted. That is, the oneof interface never contains a
50		// typed nil pointer to one of the wrapper structs.
51
52		fieldDesc: fd,
53		has: func(p pointer) bool {
54			if p.IsNil() {
55				return false
56			}
57			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
58			if rv.IsNil() || rv.Elem().Type().Elem() != ot || rv.Elem().IsNil() {
59				return false
60			}
61			return true
62		},
63		clear: func(p pointer) {
64			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
65			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
66				// NOTE: We intentionally don't check for rv.Elem().IsNil()
67				// so that (*OneofWrapperType)(nil) gets cleared to nil.
68				return
69			}
70			rv.Set(reflect.Zero(rv.Type()))
71		},
72		get: func(p pointer) pref.Value {
73			if p.IsNil() {
74				return conv.Zero()
75			}
76			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
77			if rv.IsNil() || rv.Elem().Type().Elem() != ot || rv.Elem().IsNil() {
78				return conv.Zero()
79			}
80			rv = rv.Elem().Elem().Field(0)
81			return conv.PBValueOf(rv)
82		},
83		set: func(p pointer, v pref.Value) {
84			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
85			if rv.IsNil() || rv.Elem().Type().Elem() != ot || rv.Elem().IsNil() {
86				rv.Set(reflect.New(ot))
87			}
88			rv = rv.Elem().Elem().Field(0)
89			rv.Set(conv.GoValueOf(v))
90		},
91		mutable: func(p pointer) pref.Value {
92			if !isMessage {
93				panic(fmt.Sprintf("field %v with invalid Mutable call on field with non-composite type", fd.FullName()))
94			}
95			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
96			if rv.IsNil() || rv.Elem().Type().Elem() != ot || rv.Elem().IsNil() {
97				rv.Set(reflect.New(ot))
98			}
99			rv = rv.Elem().Elem().Field(0)
100			if rv.IsNil() {
101				rv.Set(conv.GoValueOf(pref.ValueOfMessage(conv.New().Message())))
102			}
103			return conv.PBValueOf(rv)
104		},
105		newMessage: func() pref.Message {
106			return conv.New().Message()
107		},
108		newField: func() pref.Value {
109			return conv.New()
110		},
111	}
112}
113
114func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
115	ft := fs.Type
116	if ft.Kind() != reflect.Map {
117		panic(fmt.Sprintf("field %v has invalid type: got %v, want map kind", fd.FullName(), ft))
118	}
119	conv := NewConverter(ft, fd)
120
121	// TODO: Implement unsafe fast path?
122	fieldOffset := offsetOf(fs, x)
123	return fieldInfo{
124		fieldDesc: fd,
125		has: func(p pointer) bool {
126			if p.IsNil() {
127				return false
128			}
129			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
130			return rv.Len() > 0
131		},
132		clear: func(p pointer) {
133			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
134			rv.Set(reflect.Zero(rv.Type()))
135		},
136		get: func(p pointer) pref.Value {
137			if p.IsNil() {
138				return conv.Zero()
139			}
140			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
141			if rv.Len() == 0 {
142				return conv.Zero()
143			}
144			return conv.PBValueOf(rv)
145		},
146		set: func(p pointer, v pref.Value) {
147			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
148			pv := conv.GoValueOf(v)
149			if pv.IsNil() {
150				panic(fmt.Sprintf("map field %v cannot be set with read-only value", fd.FullName()))
151			}
152			rv.Set(pv)
153		},
154		mutable: func(p pointer) pref.Value {
155			v := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
156			if v.IsNil() {
157				v.Set(reflect.MakeMap(fs.Type))
158			}
159			return conv.PBValueOf(v)
160		},
161		newField: func() pref.Value {
162			return conv.New()
163		},
164	}
165}
166
167func fieldInfoForList(fd pref.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
168	ft := fs.Type
169	if ft.Kind() != reflect.Slice {
170		panic(fmt.Sprintf("field %v has invalid type: got %v, want slice kind", fd.FullName(), ft))
171	}
172	conv := NewConverter(reflect.PtrTo(ft), fd)
173
174	// TODO: Implement unsafe fast path?
175	fieldOffset := offsetOf(fs, x)
176	return fieldInfo{
177		fieldDesc: fd,
178		has: func(p pointer) bool {
179			if p.IsNil() {
180				return false
181			}
182			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
183			return rv.Len() > 0
184		},
185		clear: func(p pointer) {
186			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
187			rv.Set(reflect.Zero(rv.Type()))
188		},
189		get: func(p pointer) pref.Value {
190			if p.IsNil() {
191				return conv.Zero()
192			}
193			rv := p.Apply(fieldOffset).AsValueOf(fs.Type)
194			if rv.Elem().Len() == 0 {
195				return conv.Zero()
196			}
197			return conv.PBValueOf(rv)
198		},
199		set: func(p pointer, v pref.Value) {
200			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
201			pv := conv.GoValueOf(v)
202			if pv.IsNil() {
203				panic(fmt.Sprintf("list field %v cannot be set with read-only value", fd.FullName()))
204			}
205			rv.Set(pv.Elem())
206		},
207		mutable: func(p pointer) pref.Value {
208			v := p.Apply(fieldOffset).AsValueOf(fs.Type)
209			return conv.PBValueOf(v)
210		},
211		newField: func() pref.Value {
212			return conv.New()
213		},
214	}
215}
216
// Reusable values written into a bytes field by the scalar setter when the
// assigned value is empty: nilBytes leaves the field unpopulated, while
// emptyBytes (non-nil, zero length) keeps it populated for fields that track
// presence.
var (
	nilBytes   = reflect.ValueOf([]byte(nil))
	emptyBytes = reflect.ValueOf(make([]byte, 0))
)
221
222func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
223	ft := fs.Type
224	nullable := fd.HasPresence()
225	isBytes := ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8
226	if nullable {
227		if ft.Kind() != reflect.Ptr && ft.Kind() != reflect.Slice {
228			panic(fmt.Sprintf("field %v has invalid type: got %v, want pointer", fd.FullName(), ft))
229		}
230		if ft.Kind() == reflect.Ptr {
231			ft = ft.Elem()
232		}
233	}
234	conv := NewConverter(ft, fd)
235
236	// TODO: Implement unsafe fast path?
237	fieldOffset := offsetOf(fs, x)
238	return fieldInfo{
239		fieldDesc: fd,
240		has: func(p pointer) bool {
241			if p.IsNil() {
242				return false
243			}
244			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
245			if nullable {
246				return !rv.IsNil()
247			}
248			switch rv.Kind() {
249			case reflect.Bool:
250				return rv.Bool()
251			case reflect.Int32, reflect.Int64:
252				return rv.Int() != 0
253			case reflect.Uint32, reflect.Uint64:
254				return rv.Uint() != 0
255			case reflect.Float32, reflect.Float64:
256				return rv.Float() != 0 || math.Signbit(rv.Float())
257			case reflect.String, reflect.Slice:
258				return rv.Len() > 0
259			default:
260				panic(fmt.Sprintf("field %v has invalid type: %v", fd.FullName(), rv.Type())) // should never happen
261			}
262		},
263		clear: func(p pointer) {
264			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
265			rv.Set(reflect.Zero(rv.Type()))
266		},
267		get: func(p pointer) pref.Value {
268			if p.IsNil() {
269				return conv.Zero()
270			}
271			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
272			if nullable {
273				if rv.IsNil() {
274					return conv.Zero()
275				}
276				if rv.Kind() == reflect.Ptr {
277					rv = rv.Elem()
278				}
279			}
280			return conv.PBValueOf(rv)
281		},
282		set: func(p pointer, v pref.Value) {
283			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
284			if nullable && rv.Kind() == reflect.Ptr {
285				if rv.IsNil() {
286					rv.Set(reflect.New(ft))
287				}
288				rv = rv.Elem()
289			}
290			rv.Set(conv.GoValueOf(v))
291			if isBytes && rv.Len() == 0 {
292				if nullable {
293					rv.Set(emptyBytes) // preserve presence
294				} else {
295					rv.Set(nilBytes) // do not preserve presence
296				}
297			}
298		},
299		newField: func() pref.Value {
300			return conv.New()
301		},
302	}
303}
304
305func fieldInfoForWeakMessage(fd pref.FieldDescriptor, weakOffset offset) fieldInfo {
306	if !flags.ProtoLegacy {
307		panic("no support for proto1 weak fields")
308	}
309
310	var once sync.Once
311	var messageType pref.MessageType
312	lazyInit := func() {
313		once.Do(func() {
314			messageName := fd.Message().FullName()
315			messageType, _ = preg.GlobalTypes.FindMessageByName(messageName)
316			if messageType == nil {
317				panic(fmt.Sprintf("weak message %v for field %v is not linked in", messageName, fd.FullName()))
318			}
319		})
320	}
321
322	num := fd.Number()
323	return fieldInfo{
324		fieldDesc: fd,
325		has: func(p pointer) bool {
326			if p.IsNil() {
327				return false
328			}
329			_, ok := p.Apply(weakOffset).WeakFields().get(num)
330			return ok
331		},
332		clear: func(p pointer) {
333			p.Apply(weakOffset).WeakFields().clear(num)
334		},
335		get: func(p pointer) pref.Value {
336			lazyInit()
337			if p.IsNil() {
338				return pref.ValueOfMessage(messageType.Zero())
339			}
340			m, ok := p.Apply(weakOffset).WeakFields().get(num)
341			if !ok {
342				return pref.ValueOfMessage(messageType.Zero())
343			}
344			return pref.ValueOfMessage(m.ProtoReflect())
345		},
346		set: func(p pointer, v pref.Value) {
347			lazyInit()
348			m := v.Message()
349			if m.Descriptor() != messageType.Descriptor() {
350				if got, want := m.Descriptor().FullName(), messageType.Descriptor().FullName(); got != want {
351					panic(fmt.Sprintf("field %v has mismatching message descriptor: got %v, want %v", fd.FullName(), got, want))
352				}
353				panic(fmt.Sprintf("field %v has mismatching message descriptor: %v", fd.FullName(), m.Descriptor().FullName()))
354			}
355			p.Apply(weakOffset).WeakFields().set(num, m.Interface())
356		},
357		mutable: func(p pointer) pref.Value {
358			lazyInit()
359			fs := p.Apply(weakOffset).WeakFields()
360			m, ok := fs.get(num)
361			if !ok {
362				m = messageType.New().Interface()
363				fs.set(num, m)
364			}
365			return pref.ValueOfMessage(m.ProtoReflect())
366		},
367		newMessage: func() pref.Message {
368			lazyInit()
369			return messageType.New()
370		},
371		newField: func() pref.Value {
372			lazyInit()
373			return pref.ValueOfMessage(messageType.New())
374		},
375	}
376}
377
378func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField, x exporter) fieldInfo {
379	ft := fs.Type
380	conv := NewConverter(ft, fd)
381
382	// TODO: Implement unsafe fast path?
383	fieldOffset := offsetOf(fs, x)
384	return fieldInfo{
385		fieldDesc: fd,
386		has: func(p pointer) bool {
387			if p.IsNil() {
388				return false
389			}
390			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
391			return !rv.IsNil()
392		},
393		clear: func(p pointer) {
394			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
395			rv.Set(reflect.Zero(rv.Type()))
396		},
397		get: func(p pointer) pref.Value {
398			if p.IsNil() {
399				return conv.Zero()
400			}
401			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
402			return conv.PBValueOf(rv)
403		},
404		set: func(p pointer, v pref.Value) {
405			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
406			rv.Set(conv.GoValueOf(v))
407			if rv.IsNil() {
408				panic(fmt.Sprintf("field %v has invalid nil pointer", fd.FullName()))
409			}
410		},
411		mutable: func(p pointer) pref.Value {
412			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
413			if rv.IsNil() {
414				rv.Set(conv.GoValueOf(conv.New()))
415			}
416			return conv.PBValueOf(rv)
417		},
418		newMessage: func() pref.Message {
419			return conv.New().Message()
420		},
421		newField: func() pref.Value {
422			return conv.New()
423		},
424	}
425}
426
// oneofInfo holds the reflective accessor for a oneof declared in a message.
type oneofInfo struct {
	// oneofDesc describes the oneof.
	oneofDesc pref.OneofDescriptor
	// which returns the field number of the populated member of the oneof in
	// the message pointed to, or 0 if no member is populated.
	which     func(pointer) pref.FieldNumber
}
431
432func makeOneofInfo(od pref.OneofDescriptor, si structInfo, x exporter) *oneofInfo {
433	oi := &oneofInfo{oneofDesc: od}
434	if od.IsSynthetic() {
435		fs := si.fieldsByNumber[od.Fields().Get(0).Number()]
436		fieldOffset := offsetOf(fs, x)
437		oi.which = func(p pointer) pref.FieldNumber {
438			if p.IsNil() {
439				return 0
440			}
441			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
442			if rv.IsNil() { // valid on either *T or []byte
443				return 0
444			}
445			return od.Fields().Get(0).Number()
446		}
447	} else {
448		fs := si.oneofsByName[od.Name()]
449		fieldOffset := offsetOf(fs, x)
450		oi.which = func(p pointer) pref.FieldNumber {
451			if p.IsNil() {
452				return 0
453			}
454			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
455			if rv.IsNil() {
456				return 0
457			}
458			rv = rv.Elem()
459			if rv.IsNil() {
460				return 0
461			}
462			return si.oneofWrappersByType[rv.Type().Elem()]
463		}
464	}
465	return oi
466}
467