1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"fmt"
9	"reflect"
10
11	pref "google.golang.org/protobuf/reflect/protoreflect"
12)
13
// unwrapper unwraps the value to the underlying value.
// This is implemented by List and Map.
//
// In this file, messageConverter also probes protoreflect.Message values for
// this interface and, when present, uses the unwrapped value instead of the
// result of Message.Interface.
type unwrapper interface {
	// protoUnwrap returns the underlying Go value held by the wrapper.
	protoUnwrap() interface{}
}
19
20// A Converter coverts to/from Go reflect.Value types and protobuf protoreflect.Value types.
21type Converter interface {
22	// PBValueOf converts a reflect.Value to a protoreflect.Value.
23	PBValueOf(reflect.Value) pref.Value
24
25	// GoValueOf converts a protoreflect.Value to a reflect.Value.
26	GoValueOf(pref.Value) reflect.Value
27
28	// IsValidPB returns whether a protoreflect.Value is compatible with this type.
29	IsValidPB(pref.Value) bool
30
31	// IsValidGo returns whether a reflect.Value is compatible with this type.
32	IsValidGo(reflect.Value) bool
33
34	// New returns a new field value.
35	// For scalars, it returns the default value of the field.
36	// For composite types, it returns a new mutable value.
37	New() pref.Value
38
39	// Zero returns a new field value.
40	// For scalars, it returns the default value of the field.
41	// For composite types, it returns an immutable, empty value.
42	Zero() pref.Value
43}
44
45// NewConverter matches a Go type with a protobuf field and returns a Converter
46// that converts between the two. Enums must be a named int32 kind that
47// implements protoreflect.Enum, and messages must be pointer to a named
48// struct type that implements protoreflect.ProtoMessage.
49//
50// This matcher deliberately supports a wider range of Go types than what
51// protoc-gen-go historically generated to be able to automatically wrap some
52// v1 messages generated by other forks of protoc-gen-go.
53func NewConverter(t reflect.Type, fd pref.FieldDescriptor) Converter {
54	switch {
55	case fd.IsList():
56		return newListConverter(t, fd)
57	case fd.IsMap():
58		return newMapConverter(t, fd)
59	default:
60		return newSingularConverter(t, fd)
61	}
62	panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
63}
64
// Reflected Go types of the built-in kinds that the converters compare
// against or convert to.
var (
	boolType    = reflect.TypeOf(false)
	int32Type   = reflect.TypeOf(int32(0))
	int64Type   = reflect.TypeOf(int64(0))
	uint32Type  = reflect.TypeOf(uint32(0))
	uint64Type  = reflect.TypeOf(uint64(0))
	float32Type = reflect.TypeOf(float32(0))
	float64Type = reflect.TypeOf(float64(0))
	stringType  = reflect.TypeOf("")
	bytesType   = reflect.TypeOf([]byte{})
	byteType    = reflect.TypeOf(byte(0)) // element type of bytesType
)
77
78var (
79	boolZero    = pref.ValueOfBool(false)
80	int32Zero   = pref.ValueOfInt32(0)
81	int64Zero   = pref.ValueOfInt64(0)
82	uint32Zero  = pref.ValueOfUint32(0)
83	uint64Zero  = pref.ValueOfUint64(0)
84	float32Zero = pref.ValueOfFloat32(0)
85	float64Zero = pref.ValueOfFloat64(0)
86	stringZero  = pref.ValueOfString("")
87	bytesZero   = pref.ValueOfBytes(nil)
88)
89
90func newSingularConverter(t reflect.Type, fd pref.FieldDescriptor) Converter {
91	defVal := func(fd pref.FieldDescriptor, zero pref.Value) pref.Value {
92		if fd.Cardinality() == pref.Repeated {
93			// Default isn't defined for repeated fields.
94			return zero
95		}
96		return fd.Default()
97	}
98	switch fd.Kind() {
99	case pref.BoolKind:
100		if t.Kind() == reflect.Bool {
101			return &boolConverter{t, defVal(fd, boolZero)}
102		}
103	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
104		if t.Kind() == reflect.Int32 {
105			return &int32Converter{t, defVal(fd, int32Zero)}
106		}
107	case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
108		if t.Kind() == reflect.Int64 {
109			return &int64Converter{t, defVal(fd, int64Zero)}
110		}
111	case pref.Uint32Kind, pref.Fixed32Kind:
112		if t.Kind() == reflect.Uint32 {
113			return &uint32Converter{t, defVal(fd, uint32Zero)}
114		}
115	case pref.Uint64Kind, pref.Fixed64Kind:
116		if t.Kind() == reflect.Uint64 {
117			return &uint64Converter{t, defVal(fd, uint64Zero)}
118		}
119	case pref.FloatKind:
120		if t.Kind() == reflect.Float32 {
121			return &float32Converter{t, defVal(fd, float32Zero)}
122		}
123	case pref.DoubleKind:
124		if t.Kind() == reflect.Float64 {
125			return &float64Converter{t, defVal(fd, float64Zero)}
126		}
127	case pref.StringKind:
128		if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
129			return &stringConverter{t, defVal(fd, stringZero)}
130		}
131	case pref.BytesKind:
132		if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
133			return &bytesConverter{t, defVal(fd, bytesZero)}
134		}
135	case pref.EnumKind:
136		// Handle enums, which must be a named int32 type.
137		if t.Kind() == reflect.Int32 {
138			return newEnumConverter(t, fd)
139		}
140	case pref.MessageKind, pref.GroupKind:
141		return newMessageConverter(t)
142	}
143	panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName()))
144}
145
146type boolConverter struct {
147	goType reflect.Type
148	def    pref.Value
149}
150
151func (c *boolConverter) PBValueOf(v reflect.Value) pref.Value {
152	if v.Type() != c.goType {
153		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
154	}
155	return pref.ValueOfBool(v.Bool())
156}
157func (c *boolConverter) GoValueOf(v pref.Value) reflect.Value {
158	return reflect.ValueOf(v.Bool()).Convert(c.goType)
159}
160func (c *boolConverter) IsValidPB(v pref.Value) bool {
161	_, ok := v.Interface().(bool)
162	return ok
163}
164func (c *boolConverter) IsValidGo(v reflect.Value) bool {
165	return v.IsValid() && v.Type() == c.goType
166}
167func (c *boolConverter) New() pref.Value  { return c.def }
168func (c *boolConverter) Zero() pref.Value { return c.def }
169
170type int32Converter struct {
171	goType reflect.Type
172	def    pref.Value
173}
174
175func (c *int32Converter) PBValueOf(v reflect.Value) pref.Value {
176	if v.Type() != c.goType {
177		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
178	}
179	return pref.ValueOfInt32(int32(v.Int()))
180}
181func (c *int32Converter) GoValueOf(v pref.Value) reflect.Value {
182	return reflect.ValueOf(int32(v.Int())).Convert(c.goType)
183}
184func (c *int32Converter) IsValidPB(v pref.Value) bool {
185	_, ok := v.Interface().(int32)
186	return ok
187}
188func (c *int32Converter) IsValidGo(v reflect.Value) bool {
189	return v.IsValid() && v.Type() == c.goType
190}
191func (c *int32Converter) New() pref.Value  { return c.def }
192func (c *int32Converter) Zero() pref.Value { return c.def }
193
194type int64Converter struct {
195	goType reflect.Type
196	def    pref.Value
197}
198
199func (c *int64Converter) PBValueOf(v reflect.Value) pref.Value {
200	if v.Type() != c.goType {
201		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
202	}
203	return pref.ValueOfInt64(int64(v.Int()))
204}
205func (c *int64Converter) GoValueOf(v pref.Value) reflect.Value {
206	return reflect.ValueOf(int64(v.Int())).Convert(c.goType)
207}
208func (c *int64Converter) IsValidPB(v pref.Value) bool {
209	_, ok := v.Interface().(int64)
210	return ok
211}
212func (c *int64Converter) IsValidGo(v reflect.Value) bool {
213	return v.IsValid() && v.Type() == c.goType
214}
215func (c *int64Converter) New() pref.Value  { return c.def }
216func (c *int64Converter) Zero() pref.Value { return c.def }
217
218type uint32Converter struct {
219	goType reflect.Type
220	def    pref.Value
221}
222
223func (c *uint32Converter) PBValueOf(v reflect.Value) pref.Value {
224	if v.Type() != c.goType {
225		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
226	}
227	return pref.ValueOfUint32(uint32(v.Uint()))
228}
229func (c *uint32Converter) GoValueOf(v pref.Value) reflect.Value {
230	return reflect.ValueOf(uint32(v.Uint())).Convert(c.goType)
231}
232func (c *uint32Converter) IsValidPB(v pref.Value) bool {
233	_, ok := v.Interface().(uint32)
234	return ok
235}
236func (c *uint32Converter) IsValidGo(v reflect.Value) bool {
237	return v.IsValid() && v.Type() == c.goType
238}
239func (c *uint32Converter) New() pref.Value  { return c.def }
240func (c *uint32Converter) Zero() pref.Value { return c.def }
241
242type uint64Converter struct {
243	goType reflect.Type
244	def    pref.Value
245}
246
247func (c *uint64Converter) PBValueOf(v reflect.Value) pref.Value {
248	if v.Type() != c.goType {
249		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
250	}
251	return pref.ValueOfUint64(uint64(v.Uint()))
252}
253func (c *uint64Converter) GoValueOf(v pref.Value) reflect.Value {
254	return reflect.ValueOf(uint64(v.Uint())).Convert(c.goType)
255}
256func (c *uint64Converter) IsValidPB(v pref.Value) bool {
257	_, ok := v.Interface().(uint64)
258	return ok
259}
260func (c *uint64Converter) IsValidGo(v reflect.Value) bool {
261	return v.IsValid() && v.Type() == c.goType
262}
263func (c *uint64Converter) New() pref.Value  { return c.def }
264func (c *uint64Converter) Zero() pref.Value { return c.def }
265
266type float32Converter struct {
267	goType reflect.Type
268	def    pref.Value
269}
270
271func (c *float32Converter) PBValueOf(v reflect.Value) pref.Value {
272	if v.Type() != c.goType {
273		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
274	}
275	return pref.ValueOfFloat32(float32(v.Float()))
276}
277func (c *float32Converter) GoValueOf(v pref.Value) reflect.Value {
278	return reflect.ValueOf(float32(v.Float())).Convert(c.goType)
279}
280func (c *float32Converter) IsValidPB(v pref.Value) bool {
281	_, ok := v.Interface().(float32)
282	return ok
283}
284func (c *float32Converter) IsValidGo(v reflect.Value) bool {
285	return v.IsValid() && v.Type() == c.goType
286}
287func (c *float32Converter) New() pref.Value  { return c.def }
288func (c *float32Converter) Zero() pref.Value { return c.def }
289
290type float64Converter struct {
291	goType reflect.Type
292	def    pref.Value
293}
294
295func (c *float64Converter) PBValueOf(v reflect.Value) pref.Value {
296	if v.Type() != c.goType {
297		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
298	}
299	return pref.ValueOfFloat64(float64(v.Float()))
300}
301func (c *float64Converter) GoValueOf(v pref.Value) reflect.Value {
302	return reflect.ValueOf(float64(v.Float())).Convert(c.goType)
303}
304func (c *float64Converter) IsValidPB(v pref.Value) bool {
305	_, ok := v.Interface().(float64)
306	return ok
307}
308func (c *float64Converter) IsValidGo(v reflect.Value) bool {
309	return v.IsValid() && v.Type() == c.goType
310}
311func (c *float64Converter) New() pref.Value  { return c.def }
312func (c *float64Converter) Zero() pref.Value { return c.def }
313
314type stringConverter struct {
315	goType reflect.Type
316	def    pref.Value
317}
318
319func (c *stringConverter) PBValueOf(v reflect.Value) pref.Value {
320	if v.Type() != c.goType {
321		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
322	}
323	return pref.ValueOfString(v.Convert(stringType).String())
324}
325func (c *stringConverter) GoValueOf(v pref.Value) reflect.Value {
326	// pref.Value.String never panics, so we go through an interface
327	// conversion here to check the type.
328	s := v.Interface().(string)
329	if c.goType.Kind() == reflect.Slice && s == "" {
330		return reflect.Zero(c.goType) // ensure empty string is []byte(nil)
331	}
332	return reflect.ValueOf(s).Convert(c.goType)
333}
334func (c *stringConverter) IsValidPB(v pref.Value) bool {
335	_, ok := v.Interface().(string)
336	return ok
337}
338func (c *stringConverter) IsValidGo(v reflect.Value) bool {
339	return v.IsValid() && v.Type() == c.goType
340}
341func (c *stringConverter) New() pref.Value  { return c.def }
342func (c *stringConverter) Zero() pref.Value { return c.def }
343
344type bytesConverter struct {
345	goType reflect.Type
346	def    pref.Value
347}
348
349func (c *bytesConverter) PBValueOf(v reflect.Value) pref.Value {
350	if v.Type() != c.goType {
351		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
352	}
353	if c.goType.Kind() == reflect.String && v.Len() == 0 {
354		return pref.ValueOfBytes(nil) // ensure empty string is []byte(nil)
355	}
356	return pref.ValueOfBytes(v.Convert(bytesType).Bytes())
357}
358func (c *bytesConverter) GoValueOf(v pref.Value) reflect.Value {
359	return reflect.ValueOf(v.Bytes()).Convert(c.goType)
360}
361func (c *bytesConverter) IsValidPB(v pref.Value) bool {
362	_, ok := v.Interface().([]byte)
363	return ok
364}
365func (c *bytesConverter) IsValidGo(v reflect.Value) bool {
366	return v.IsValid() && v.Type() == c.goType
367}
368func (c *bytesConverter) New() pref.Value  { return c.def }
369func (c *bytesConverter) Zero() pref.Value { return c.def }
370
371type enumConverter struct {
372	goType reflect.Type
373	def    pref.Value
374}
375
376func newEnumConverter(goType reflect.Type, fd pref.FieldDescriptor) Converter {
377	var def pref.Value
378	if fd.Cardinality() == pref.Repeated {
379		def = pref.ValueOfEnum(fd.Enum().Values().Get(0).Number())
380	} else {
381		def = fd.Default()
382	}
383	return &enumConverter{goType, def}
384}
385
386func (c *enumConverter) PBValueOf(v reflect.Value) pref.Value {
387	if v.Type() != c.goType {
388		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
389	}
390	return pref.ValueOfEnum(pref.EnumNumber(v.Int()))
391}
392
393func (c *enumConverter) GoValueOf(v pref.Value) reflect.Value {
394	return reflect.ValueOf(v.Enum()).Convert(c.goType)
395}
396
397func (c *enumConverter) IsValidPB(v pref.Value) bool {
398	_, ok := v.Interface().(pref.EnumNumber)
399	return ok
400}
401
402func (c *enumConverter) IsValidGo(v reflect.Value) bool {
403	return v.IsValid() && v.Type() == c.goType
404}
405
406func (c *enumConverter) New() pref.Value {
407	return c.def
408}
409
410func (c *enumConverter) Zero() pref.Value {
411	return c.def
412}
413
414type messageConverter struct {
415	goType reflect.Type
416}
417
418func newMessageConverter(goType reflect.Type) Converter {
419	return &messageConverter{goType}
420}
421
422func (c *messageConverter) PBValueOf(v reflect.Value) pref.Value {
423	if v.Type() != c.goType {
424		panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
425	}
426	if m, ok := v.Interface().(pref.ProtoMessage); ok {
427		return pref.ValueOfMessage(m.ProtoReflect())
428	}
429	return pref.ValueOfMessage(legacyWrapMessage(v))
430}
431
432func (c *messageConverter) GoValueOf(v pref.Value) reflect.Value {
433	m := v.Message()
434	var rv reflect.Value
435	if u, ok := m.(unwrapper); ok {
436		rv = reflect.ValueOf(u.protoUnwrap())
437	} else {
438		rv = reflect.ValueOf(m.Interface())
439	}
440	if rv.Type() != c.goType {
441		panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), c.goType))
442	}
443	return rv
444}
445
446func (c *messageConverter) IsValidPB(v pref.Value) bool {
447	m := v.Message()
448	var rv reflect.Value
449	if u, ok := m.(unwrapper); ok {
450		rv = reflect.ValueOf(u.protoUnwrap())
451	} else {
452		rv = reflect.ValueOf(m.Interface())
453	}
454	return rv.Type() == c.goType
455}
456
457func (c *messageConverter) IsValidGo(v reflect.Value) bool {
458	return v.IsValid() && v.Type() == c.goType
459}
460
461func (c *messageConverter) New() pref.Value {
462	return c.PBValueOf(reflect.New(c.goType.Elem()))
463}
464
465func (c *messageConverter) Zero() pref.Value {
466	return c.PBValueOf(reflect.Zero(c.goType))
467}
468