1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"fmt"
9	"reflect"
10
11	"google.golang.org/protobuf/encoding/protowire"
12	"google.golang.org/protobuf/internal/strs"
13	pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
// pointerCoderFuncs is a set of pointer encoding functions.
//
// Each function operates on a single field of a Go message struct. The field
// is reached through the pointer p (dst and src for merge) and described by
// f. Values of this type are returned by fieldCoder.
//
// size returns an int, presumably the encoded length of the field at p.
// marshal encodes the field at p onto b and returns the resulting buffer or
// an error (presumably appending to b, per the usual Go Append convention).
// unmarshal decodes b, whose wire type is wtyp, into the field at p and
// reports the outcome as an unmarshalOutput. isInit returns an error,
// presumably when the value at p is not fully initialized (TODO confirm).
// merge combines the field at src into the field at dst.
//
// NOTE(review): mi looks like it is only populated for message-typed fields;
// confirm at the construction sites before relying on it being non-nil.
type pointerCoderFuncs struct {
	mi        *MessageInfo
	size      func(p pointer, f *coderFieldInfo, opts marshalOptions) int
	marshal   func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error)
	unmarshal func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error)
	isInit    func(p pointer, f *coderFieldInfo) error
	merge     func(dst, src pointer, f *coderFieldInfo, opts mergeOptions)
}
25
// valueCoderFuncs is a set of protoreflect.Value encoding functions.
//
// It parallels pointerCoderFuncs but operates on a pref.Value rather than on
// a pointer into a Go struct. Values of this type are returned by
// encoderFuncsForValue, which serves extension values and map encoding.
//
// Because there is no per-field coderFieldInfo here, the tag information is
// passed explicitly: size receives the tag size and marshal receives the
// wire tag. unmarshal receives the field number num and wire type wtyp and
// returns the resulting Value together with an unmarshalOutput. merge
// returns the merged Value instead of writing through a pointer.
type valueCoderFuncs struct {
	size      func(v pref.Value, tagsize int, opts marshalOptions) int
	marshal   func(b []byte, v pref.Value, wiretag uint64, opts marshalOptions) ([]byte, error)
	unmarshal func(b []byte, v pref.Value, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (pref.Value, unmarshalOutput, error)
	isInit    func(v pref.Value) error
	merge     func(dst, src pref.Value, opts mergeOptions) pref.Value
}
34
35// fieldCoder returns pointer functions for a field, used for operating on
36// struct fields.
37func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
38	switch {
39	case fd.IsMap():
40		return encoderFuncsForMap(fd, ft)
41	case fd.Cardinality() == pref.Repeated && !fd.IsPacked():
42		// Repeated fields (not packed).
43		if ft.Kind() != reflect.Slice {
44			break
45		}
46		ft := ft.Elem()
47		switch fd.Kind() {
48		case pref.BoolKind:
49			if ft.Kind() == reflect.Bool {
50				return nil, coderBoolSlice
51			}
52		case pref.EnumKind:
53			if ft.Kind() == reflect.Int32 {
54				return nil, coderEnumSlice
55			}
56		case pref.Int32Kind:
57			if ft.Kind() == reflect.Int32 {
58				return nil, coderInt32Slice
59			}
60		case pref.Sint32Kind:
61			if ft.Kind() == reflect.Int32 {
62				return nil, coderSint32Slice
63			}
64		case pref.Uint32Kind:
65			if ft.Kind() == reflect.Uint32 {
66				return nil, coderUint32Slice
67			}
68		case pref.Int64Kind:
69			if ft.Kind() == reflect.Int64 {
70				return nil, coderInt64Slice
71			}
72		case pref.Sint64Kind:
73			if ft.Kind() == reflect.Int64 {
74				return nil, coderSint64Slice
75			}
76		case pref.Uint64Kind:
77			if ft.Kind() == reflect.Uint64 {
78				return nil, coderUint64Slice
79			}
80		case pref.Sfixed32Kind:
81			if ft.Kind() == reflect.Int32 {
82				return nil, coderSfixed32Slice
83			}
84		case pref.Fixed32Kind:
85			if ft.Kind() == reflect.Uint32 {
86				return nil, coderFixed32Slice
87			}
88		case pref.FloatKind:
89			if ft.Kind() == reflect.Float32 {
90				return nil, coderFloatSlice
91			}
92		case pref.Sfixed64Kind:
93			if ft.Kind() == reflect.Int64 {
94				return nil, coderSfixed64Slice
95			}
96		case pref.Fixed64Kind:
97			if ft.Kind() == reflect.Uint64 {
98				return nil, coderFixed64Slice
99			}
100		case pref.DoubleKind:
101			if ft.Kind() == reflect.Float64 {
102				return nil, coderDoubleSlice
103			}
104		case pref.StringKind:
105			if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
106				return nil, coderStringSliceValidateUTF8
107			}
108			if ft.Kind() == reflect.String {
109				return nil, coderStringSlice
110			}
111			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
112				return nil, coderBytesSliceValidateUTF8
113			}
114			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
115				return nil, coderBytesSlice
116			}
117		case pref.BytesKind:
118			if ft.Kind() == reflect.String {
119				return nil, coderStringSlice
120			}
121			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
122				return nil, coderBytesSlice
123			}
124		case pref.MessageKind:
125			return getMessageInfo(ft), makeMessageSliceFieldCoder(fd, ft)
126		case pref.GroupKind:
127			return getMessageInfo(ft), makeGroupSliceFieldCoder(fd, ft)
128		}
129	case fd.Cardinality() == pref.Repeated && fd.IsPacked():
130		// Packed repeated fields.
131		//
132		// Only repeated fields of primitive numeric types
133		// (Varint, Fixed32, or Fixed64 wire type) can be packed.
134		if ft.Kind() != reflect.Slice {
135			break
136		}
137		ft := ft.Elem()
138		switch fd.Kind() {
139		case pref.BoolKind:
140			if ft.Kind() == reflect.Bool {
141				return nil, coderBoolPackedSlice
142			}
143		case pref.EnumKind:
144			if ft.Kind() == reflect.Int32 {
145				return nil, coderEnumPackedSlice
146			}
147		case pref.Int32Kind:
148			if ft.Kind() == reflect.Int32 {
149				return nil, coderInt32PackedSlice
150			}
151		case pref.Sint32Kind:
152			if ft.Kind() == reflect.Int32 {
153				return nil, coderSint32PackedSlice
154			}
155		case pref.Uint32Kind:
156			if ft.Kind() == reflect.Uint32 {
157				return nil, coderUint32PackedSlice
158			}
159		case pref.Int64Kind:
160			if ft.Kind() == reflect.Int64 {
161				return nil, coderInt64PackedSlice
162			}
163		case pref.Sint64Kind:
164			if ft.Kind() == reflect.Int64 {
165				return nil, coderSint64PackedSlice
166			}
167		case pref.Uint64Kind:
168			if ft.Kind() == reflect.Uint64 {
169				return nil, coderUint64PackedSlice
170			}
171		case pref.Sfixed32Kind:
172			if ft.Kind() == reflect.Int32 {
173				return nil, coderSfixed32PackedSlice
174			}
175		case pref.Fixed32Kind:
176			if ft.Kind() == reflect.Uint32 {
177				return nil, coderFixed32PackedSlice
178			}
179		case pref.FloatKind:
180			if ft.Kind() == reflect.Float32 {
181				return nil, coderFloatPackedSlice
182			}
183		case pref.Sfixed64Kind:
184			if ft.Kind() == reflect.Int64 {
185				return nil, coderSfixed64PackedSlice
186			}
187		case pref.Fixed64Kind:
188			if ft.Kind() == reflect.Uint64 {
189				return nil, coderFixed64PackedSlice
190			}
191		case pref.DoubleKind:
192			if ft.Kind() == reflect.Float64 {
193				return nil, coderDoublePackedSlice
194			}
195		}
196	case fd.Kind() == pref.MessageKind:
197		return getMessageInfo(ft), makeMessageFieldCoder(fd, ft)
198	case fd.Kind() == pref.GroupKind:
199		return getMessageInfo(ft), makeGroupFieldCoder(fd, ft)
200	case fd.Syntax() == pref.Proto3 && fd.ContainingOneof() == nil:
201		// Populated oneof fields always encode even if set to the zero value,
202		// which normally are not encoded in proto3.
203		switch fd.Kind() {
204		case pref.BoolKind:
205			if ft.Kind() == reflect.Bool {
206				return nil, coderBoolNoZero
207			}
208		case pref.EnumKind:
209			if ft.Kind() == reflect.Int32 {
210				return nil, coderEnumNoZero
211			}
212		case pref.Int32Kind:
213			if ft.Kind() == reflect.Int32 {
214				return nil, coderInt32NoZero
215			}
216		case pref.Sint32Kind:
217			if ft.Kind() == reflect.Int32 {
218				return nil, coderSint32NoZero
219			}
220		case pref.Uint32Kind:
221			if ft.Kind() == reflect.Uint32 {
222				return nil, coderUint32NoZero
223			}
224		case pref.Int64Kind:
225			if ft.Kind() == reflect.Int64 {
226				return nil, coderInt64NoZero
227			}
228		case pref.Sint64Kind:
229			if ft.Kind() == reflect.Int64 {
230				return nil, coderSint64NoZero
231			}
232		case pref.Uint64Kind:
233			if ft.Kind() == reflect.Uint64 {
234				return nil, coderUint64NoZero
235			}
236		case pref.Sfixed32Kind:
237			if ft.Kind() == reflect.Int32 {
238				return nil, coderSfixed32NoZero
239			}
240		case pref.Fixed32Kind:
241			if ft.Kind() == reflect.Uint32 {
242				return nil, coderFixed32NoZero
243			}
244		case pref.FloatKind:
245			if ft.Kind() == reflect.Float32 {
246				return nil, coderFloatNoZero
247			}
248		case pref.Sfixed64Kind:
249			if ft.Kind() == reflect.Int64 {
250				return nil, coderSfixed64NoZero
251			}
252		case pref.Fixed64Kind:
253			if ft.Kind() == reflect.Uint64 {
254				return nil, coderFixed64NoZero
255			}
256		case pref.DoubleKind:
257			if ft.Kind() == reflect.Float64 {
258				return nil, coderDoubleNoZero
259			}
260		case pref.StringKind:
261			if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
262				return nil, coderStringNoZeroValidateUTF8
263			}
264			if ft.Kind() == reflect.String {
265				return nil, coderStringNoZero
266			}
267			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
268				return nil, coderBytesNoZeroValidateUTF8
269			}
270			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
271				return nil, coderBytesNoZero
272			}
273		case pref.BytesKind:
274			if ft.Kind() == reflect.String {
275				return nil, coderStringNoZero
276			}
277			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
278				return nil, coderBytesNoZero
279			}
280		}
281	case ft.Kind() == reflect.Ptr:
282		ft := ft.Elem()
283		switch fd.Kind() {
284		case pref.BoolKind:
285			if ft.Kind() == reflect.Bool {
286				return nil, coderBoolPtr
287			}
288		case pref.EnumKind:
289			if ft.Kind() == reflect.Int32 {
290				return nil, coderEnumPtr
291			}
292		case pref.Int32Kind:
293			if ft.Kind() == reflect.Int32 {
294				return nil, coderInt32Ptr
295			}
296		case pref.Sint32Kind:
297			if ft.Kind() == reflect.Int32 {
298				return nil, coderSint32Ptr
299			}
300		case pref.Uint32Kind:
301			if ft.Kind() == reflect.Uint32 {
302				return nil, coderUint32Ptr
303			}
304		case pref.Int64Kind:
305			if ft.Kind() == reflect.Int64 {
306				return nil, coderInt64Ptr
307			}
308		case pref.Sint64Kind:
309			if ft.Kind() == reflect.Int64 {
310				return nil, coderSint64Ptr
311			}
312		case pref.Uint64Kind:
313			if ft.Kind() == reflect.Uint64 {
314				return nil, coderUint64Ptr
315			}
316		case pref.Sfixed32Kind:
317			if ft.Kind() == reflect.Int32 {
318				return nil, coderSfixed32Ptr
319			}
320		case pref.Fixed32Kind:
321			if ft.Kind() == reflect.Uint32 {
322				return nil, coderFixed32Ptr
323			}
324		case pref.FloatKind:
325			if ft.Kind() == reflect.Float32 {
326				return nil, coderFloatPtr
327			}
328		case pref.Sfixed64Kind:
329			if ft.Kind() == reflect.Int64 {
330				return nil, coderSfixed64Ptr
331			}
332		case pref.Fixed64Kind:
333			if ft.Kind() == reflect.Uint64 {
334				return nil, coderFixed64Ptr
335			}
336		case pref.DoubleKind:
337			if ft.Kind() == reflect.Float64 {
338				return nil, coderDoublePtr
339			}
340		case pref.StringKind:
341			if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
342				return nil, coderStringPtrValidateUTF8
343			}
344			if ft.Kind() == reflect.String {
345				return nil, coderStringPtr
346			}
347		case pref.BytesKind:
348			if ft.Kind() == reflect.String {
349				return nil, coderStringPtr
350			}
351		}
352	default:
353		switch fd.Kind() {
354		case pref.BoolKind:
355			if ft.Kind() == reflect.Bool {
356				return nil, coderBool
357			}
358		case pref.EnumKind:
359			if ft.Kind() == reflect.Int32 {
360				return nil, coderEnum
361			}
362		case pref.Int32Kind:
363			if ft.Kind() == reflect.Int32 {
364				return nil, coderInt32
365			}
366		case pref.Sint32Kind:
367			if ft.Kind() == reflect.Int32 {
368				return nil, coderSint32
369			}
370		case pref.Uint32Kind:
371			if ft.Kind() == reflect.Uint32 {
372				return nil, coderUint32
373			}
374		case pref.Int64Kind:
375			if ft.Kind() == reflect.Int64 {
376				return nil, coderInt64
377			}
378		case pref.Sint64Kind:
379			if ft.Kind() == reflect.Int64 {
380				return nil, coderSint64
381			}
382		case pref.Uint64Kind:
383			if ft.Kind() == reflect.Uint64 {
384				return nil, coderUint64
385			}
386		case pref.Sfixed32Kind:
387			if ft.Kind() == reflect.Int32 {
388				return nil, coderSfixed32
389			}
390		case pref.Fixed32Kind:
391			if ft.Kind() == reflect.Uint32 {
392				return nil, coderFixed32
393			}
394		case pref.FloatKind:
395			if ft.Kind() == reflect.Float32 {
396				return nil, coderFloat
397			}
398		case pref.Sfixed64Kind:
399			if ft.Kind() == reflect.Int64 {
400				return nil, coderSfixed64
401			}
402		case pref.Fixed64Kind:
403			if ft.Kind() == reflect.Uint64 {
404				return nil, coderFixed64
405			}
406		case pref.DoubleKind:
407			if ft.Kind() == reflect.Float64 {
408				return nil, coderDouble
409			}
410		case pref.StringKind:
411			if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
412				return nil, coderStringValidateUTF8
413			}
414			if ft.Kind() == reflect.String {
415				return nil, coderString
416			}
417			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
418				return nil, coderBytesValidateUTF8
419			}
420			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
421				return nil, coderBytes
422			}
423		case pref.BytesKind:
424			if ft.Kind() == reflect.String {
425				return nil, coderString
426			}
427			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
428				return nil, coderBytes
429			}
430		}
431	}
432	panic(fmt.Sprintf("invalid type: no encoder for %v %v %v/%v", fd.FullName(), fd.Cardinality(), fd.Kind(), ft))
433}
434
435// encoderFuncsForValue returns value functions for a field, used for
436// extension values and map encoding.
437func encoderFuncsForValue(fd pref.FieldDescriptor) valueCoderFuncs {
438	switch {
439	case fd.Cardinality() == pref.Repeated && !fd.IsPacked():
440		switch fd.Kind() {
441		case pref.BoolKind:
442			return coderBoolSliceValue
443		case pref.EnumKind:
444			return coderEnumSliceValue
445		case pref.Int32Kind:
446			return coderInt32SliceValue
447		case pref.Sint32Kind:
448			return coderSint32SliceValue
449		case pref.Uint32Kind:
450			return coderUint32SliceValue
451		case pref.Int64Kind:
452			return coderInt64SliceValue
453		case pref.Sint64Kind:
454			return coderSint64SliceValue
455		case pref.Uint64Kind:
456			return coderUint64SliceValue
457		case pref.Sfixed32Kind:
458			return coderSfixed32SliceValue
459		case pref.Fixed32Kind:
460			return coderFixed32SliceValue
461		case pref.FloatKind:
462			return coderFloatSliceValue
463		case pref.Sfixed64Kind:
464			return coderSfixed64SliceValue
465		case pref.Fixed64Kind:
466			return coderFixed64SliceValue
467		case pref.DoubleKind:
468			return coderDoubleSliceValue
469		case pref.StringKind:
470			// We don't have a UTF-8 validating coder for repeated string fields.
471			// Value coders are used for extensions and maps.
472			// Extensions are never proto3, and maps never contain lists.
473			return coderStringSliceValue
474		case pref.BytesKind:
475			return coderBytesSliceValue
476		case pref.MessageKind:
477			return coderMessageSliceValue
478		case pref.GroupKind:
479			return coderGroupSliceValue
480		}
481	case fd.Cardinality() == pref.Repeated && fd.IsPacked():
482		switch fd.Kind() {
483		case pref.BoolKind:
484			return coderBoolPackedSliceValue
485		case pref.EnumKind:
486			return coderEnumPackedSliceValue
487		case pref.Int32Kind:
488			return coderInt32PackedSliceValue
489		case pref.Sint32Kind:
490			return coderSint32PackedSliceValue
491		case pref.Uint32Kind:
492			return coderUint32PackedSliceValue
493		case pref.Int64Kind:
494			return coderInt64PackedSliceValue
495		case pref.Sint64Kind:
496			return coderSint64PackedSliceValue
497		case pref.Uint64Kind:
498			return coderUint64PackedSliceValue
499		case pref.Sfixed32Kind:
500			return coderSfixed32PackedSliceValue
501		case pref.Fixed32Kind:
502			return coderFixed32PackedSliceValue
503		case pref.FloatKind:
504			return coderFloatPackedSliceValue
505		case pref.Sfixed64Kind:
506			return coderSfixed64PackedSliceValue
507		case pref.Fixed64Kind:
508			return coderFixed64PackedSliceValue
509		case pref.DoubleKind:
510			return coderDoublePackedSliceValue
511		}
512	default:
513		switch fd.Kind() {
514		default:
515		case pref.BoolKind:
516			return coderBoolValue
517		case pref.EnumKind:
518			return coderEnumValue
519		case pref.Int32Kind:
520			return coderInt32Value
521		case pref.Sint32Kind:
522			return coderSint32Value
523		case pref.Uint32Kind:
524			return coderUint32Value
525		case pref.Int64Kind:
526			return coderInt64Value
527		case pref.Sint64Kind:
528			return coderSint64Value
529		case pref.Uint64Kind:
530			return coderUint64Value
531		case pref.Sfixed32Kind:
532			return coderSfixed32Value
533		case pref.Fixed32Kind:
534			return coderFixed32Value
535		case pref.FloatKind:
536			return coderFloatValue
537		case pref.Sfixed64Kind:
538			return coderSfixed64Value
539		case pref.Fixed64Kind:
540			return coderFixed64Value
541		case pref.DoubleKind:
542			return coderDoubleValue
543		case pref.StringKind:
544			if strs.EnforceUTF8(fd) {
545				return coderStringValueValidateUTF8
546			}
547			return coderStringValue
548		case pref.BytesKind:
549			return coderBytesValue
550		case pref.MessageKind:
551			return coderMessageValue
552		case pref.GroupKind:
553			return coderGroupValue
554		}
555	}
556	panic(fmt.Sprintf("invalid field: no encoder for %v %v %v", fd.FullName(), fd.Cardinality(), fd.Kind()))
557}
558