1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"strings"
9	"text/template"
10)
11
12type WireType string
13
14const (
15	WireVarint  WireType = "Varint"
16	WireFixed32 WireType = "Fixed32"
17	WireFixed64 WireType = "Fixed64"
18	WireBytes   WireType = "Bytes"
19	WireGroup   WireType = "Group"
20)
21
22func (w WireType) Expr() Expr {
23	if w == WireGroup {
24		return "protowire.StartGroupType"
25	}
26	return "protowire." + Expr(w) + "Type"
27}
28
29func (w WireType) Packable() bool {
30	return w == WireVarint || w == WireFixed32 || w == WireFixed64
31}
32
33func (w WireType) ConstSize() bool {
34	return w == WireFixed32 || w == WireFixed64
35}
36
37type GoType string
38
39var GoTypes = []GoType{
40	GoBool,
41	GoInt32,
42	GoUint32,
43	GoInt64,
44	GoUint64,
45	GoFloat32,
46	GoFloat64,
47	GoString,
48	GoBytes,
49}
50
51const (
52	GoBool    = "bool"
53	GoInt32   = "int32"
54	GoUint32  = "uint32"
55	GoInt64   = "int64"
56	GoUint64  = "uint64"
57	GoFloat32 = "float32"
58	GoFloat64 = "float64"
59	GoString  = "string"
60	GoBytes   = "[]byte"
61)
62
63func (g GoType) Zero() Expr {
64	switch g {
65	case GoBool:
66		return "false"
67	case GoString:
68		return `""`
69	case GoBytes:
70		return "nil"
71	}
72	return "0"
73}
74
75// Kind is the reflect.Kind of the type.
76func (g GoType) Kind() Expr {
77	if g == "" || g == GoBytes {
78		return ""
79	}
80	return "reflect." + Expr(strings.ToUpper(string(g[:1]))+string(g[1:]))
81}
82
83// PointerMethod is the "internal/impl".pointer method used to access a pointer to this type.
84func (g GoType) PointerMethod() Expr {
85	if g == GoBytes {
86		return "Bytes"
87	}
88	return Expr(strings.ToUpper(string(g[:1])) + string(g[1:]))
89}
90
91type ProtoKind struct {
92	Name     string
93	WireType WireType
94
95	// Conversions to/from protoreflect.Value.
96	ToValue   Expr
97	FromValue Expr
98
99	// Conversions to/from generated structures.
100	GoType         GoType
101	ToGoType       Expr
102	ToGoTypeNoZero Expr
103	FromGoType     Expr
104	NoPointer      bool
105	NoValueCodec   bool
106}
107
108func (k ProtoKind) Expr() Expr {
109	return "protoreflect." + Expr(k.Name) + "Kind"
110}
111
// ProtoKinds is the table of protobuf kinds that the decode, encode and size
// templates in this file range over; its order is the order of the generated
// switch cases.
//
// Inside ToValue, v is the value returned by the protowire Consume function
// for the kind's wire type (ConsumeGroup for groups). Inside FromValue, v is a
// protoreflect.Value and the expression yields the argument passed to the
// matching protowire Append function; the size template also reuses it.
// The expressions refer to math, protowire, protoreflect and emptyBuf (emitted
// by the decode template), which must be in scope where the output compiles.
var ProtoKinds = []ProtoKind{
	{
		Name:       "Bool",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfBool(protowire.DecodeBool(v))",
		FromValue:  "protowire.EncodeBool(v.Bool())",
		GoType:     GoBool,
		ToGoType:   "protowire.DecodeBool(v)",
		FromGoType: "protowire.EncodeBool(v)",
	},
	// Enum has no GoType and hence no Go-struct conversions; its value is
	// carried as a protoreflect.EnumNumber.
	{
		Name:      "Enum",
		WireType:  WireVarint,
		ToValue:   "protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))",
		FromValue: "uint64(v.Enum())",
	},
	{
		Name:       "Int32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt32(int32(v))",
		FromValue:  "uint64(int32(v.Int()))",
		GoType:     GoInt32,
		ToGoType:   "int32(v)",
		FromGoType: "uint64(v)",
	},
	// Sint32 is zig-zag encoded: decoding keeps only the low 32 bits of v
	// before DecodeZigZag, and encoding narrows to int32 before widening.
	{
		Name:       "Sint32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32)))",
		FromValue:  "protowire.EncodeZigZag(int64(int32(v.Int())))",
		GoType:     GoInt32,
		ToGoType:   "int32(protowire.DecodeZigZag(v & math.MaxUint32))",
		FromGoType: "protowire.EncodeZigZag(int64(v))",
	},
	{
		Name:       "Uint32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfUint32(uint32(v))",
		FromValue:  "uint64(uint32(v.Uint()))",
		GoType:     GoUint32,
		ToGoType:   "uint32(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Int64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt64(int64(v))",
		FromValue:  "uint64(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "int64(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Sint64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt64(protowire.DecodeZigZag(v))",
		FromValue:  "protowire.EncodeZigZag(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "protowire.DecodeZigZag(v)",
		FromGoType: "protowire.EncodeZigZag(v)",
	},
	{
		Name:       "Uint64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfUint64(v)",
		FromValue:  "v.Uint()",
		GoType:     GoUint64,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Sfixed32",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfInt32(int32(v))",
		FromValue:  "uint32(v.Int())",
		GoType:     GoInt32,
		ToGoType:   "int32(v)",
		FromGoType: "uint32(v)",
	},
	{
		Name:       "Fixed32",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfUint32(uint32(v))",
		FromValue:  "uint32(v.Uint())",
		GoType:     GoUint32,
		ToGoType:   "v",
		FromGoType: "v",
	},
	// Floating-point kinds travel as their IEEE-754 bit patterns.
	{
		Name:       "Float",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))",
		FromValue:  "math.Float32bits(float32(v.Float()))",
		GoType:     GoFloat32,
		ToGoType:   "math.Float32frombits(v)",
		FromGoType: "math.Float32bits(v)",
	},
	{
		Name:       "Sfixed64",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfInt64(int64(v))",
		FromValue:  "uint64(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "int64(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Fixed64",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfUint64(v)",
		FromValue:  "v.Uint()",
		GoType:     GoUint64,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Double",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfFloat64(math.Float64frombits(v))",
		FromValue:  "math.Float64bits(v.Float())",
		GoType:     GoFloat64,
		ToGoType:   "math.Float64frombits(v)",
		FromGoType: "math.Float64bits(v)",
	},
	{
		Name:       "String",
		WireType:   WireBytes,
		ToValue:    "protoreflect.ValueOfString(string(v))",
		FromValue:  "v.String()",
		GoType:     GoString,
		ToGoType:   "string(v)",
		FromGoType: "v",
	},
	// Bytes copies v. Appending to emptyBuf makes an empty payload decode to a
	// non-nil, zero-length slice; ToGoTypeNoZero appends to nil instead, so an
	// empty payload yields nil there.
	{
		Name:           "Bytes",
		WireType:       WireBytes,
		ToValue:        "protoreflect.ValueOfBytes(append(emptyBuf[:], v...))",
		FromValue:      "v.Bytes()",
		GoType:         GoBytes,
		ToGoType:       "append(emptyBuf[:], v...)",
		ToGoTypeNoZero: "append(([]byte)(nil), v...)",
		FromGoType:     "v",
		NoPointer:      true,
	},
	// Message and Group: the scalar decoder returns the raw payload bytes
	// uncopied. The list, encode and size templates special-case these names
	// and recurse through unmarshalMessage/marshalMessage/size instead.
	{
		Name:         "Message",
		WireType:     WireBytes,
		ToValue:      "protoreflect.ValueOfBytes(v)",
		FromValue:    "v",
		NoValueCodec: true,
	},
	{
		Name:         "Group",
		WireType:     WireGroup,
		ToValue:      "protoreflect.ValueOfBytes(v)",
		FromValue:    "v",
		NoValueCodec: true,
	},
}
271
272func generateProtoDecode() string {
273	return mustExecute(protoDecodeTemplate, ProtoKinds)
274}
275
// protoDecodeTemplate renders, from a []ProtoKind, the generated
// UnmarshalOptions.unmarshalScalar and UnmarshalOptions.unmarshalList methods
// (one switch case per kind), plus the emptyBuf array that the Bytes
// conversions append to. In unmarshalList, kinds whose wire type is Packable
// also accept a packed payload arriving as protowire.BytesType.
//
// The raw string is emitted verbatim apart from its actions, and the "{{-" /
// "-}}" markers trim neighbouring whitespace, so edits here change the
// generated file directly.
//
// NOTE(review): the output refers to errUnknown, errDecode, strs, utf8, errors
// and o.unmarshalMessage, which must exist in the destination package (not
// visible here).
var protoDecodeTemplate = template.Must(template.New("").Parse(`
// unmarshalScalar decodes a value of the given kind.
//
// Message values are decoded into a []byte which aliases the input data.
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		if wtyp != {{.WireType.Expr}} {
			return val, 0, errUnknown
		}
		{{if (eq .WireType "Group") -}}
		v, n := protowire.ConsumeGroup(fd.Number(), b)
		{{- else -}}
		v, n := protowire.Consume{{.WireType}}(b)
		{{- end}}
		if n < 0 {
			return val, 0, errDecode
		}
		{{if (eq .Name "String") -}}
		if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
			return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
		}
		{{end -}}
		return {{.ToValue}}, n, nil
	{{- end}}
	default:
		return val, 0, errUnknown
	}
}

func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list protoreflect.List, fd protoreflect.FieldDescriptor) (n int, err error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		{{- if .WireType.Packable}}
		if wtyp == protowire.BytesType {
			buf, n := protowire.ConsumeBytes(b)
			if n < 0 {
				return 0, errDecode
			}
			for len(buf) > 0 {
				v, n := protowire.Consume{{.WireType}}(buf)
				if n < 0 {
					return 0, errDecode
				}
				buf = buf[n:]
				list.Append({{.ToValue}})
			}
			return n, nil
		}
		{{- end}}
		if wtyp != {{.WireType.Expr}} {
			return 0, errUnknown
		}
		{{if (eq .WireType "Group") -}}
		v, n := protowire.ConsumeGroup(fd.Number(), b)
		{{- else -}}
		v, n := protowire.Consume{{.WireType}}(b)
		{{- end}}
		if n < 0 {
			return 0, errDecode
		}
		{{if (eq .Name "String") -}}
		if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
			return 0, errors.InvalidUTF8(string(fd.FullName()))
		}
		{{end -}}
		{{if or (eq .Name "Message") (eq .Name "Group") -}}
		m := list.NewElement()
		if err := o.unmarshalMessage(v, m.Message()); err != nil {
			return 0, err
		}
		list.Append(m)
		{{- else -}}
		list.Append({{.ToValue}})
		{{- end}}
		return n, nil
	{{- end}}
	default:
		return 0, errUnknown
	}
}

// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
var emptyBuf [0]byte
`))
363
364func generateProtoEncode() string {
365	return mustExecute(protoEncodeTemplate, ProtoKinds)
366}
367
// protoEncodeTemplate renders, from a []ProtoKind, the generated wireTypes map
// (protoreflect.Kind to protowire.Type) and the MarshalOptions.marshalSingular
// method. In marshalSingular:
//   - String rejects invalid UTF-8 when strs.EnforceUTF8(fd) is true;
//   - Message is written behind a speculative length prefix;
//   - Group writes the message followed by an end-group tag (the start tag is
//     presumably written by the caller — not visible here);
//   - every other kind appends FromValue with protowire.Append<WireType>.
//
// The raw string is emitted verbatim apart from its actions; mind the
// whitespace-trimming "{{-" / "-}}" markers when editing.
var protoEncodeTemplate = template.Must(template.New("").Parse(`
var wireTypes = map[protoreflect.Kind]protowire.Type{
{{- range .}}
	{{.Expr}}: {{.WireType.Expr}},
{{- end}}
}

func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		{{- if (eq .Name "String") }}
		if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
			return b, errors.InvalidUTF8(string(fd.FullName()))
		}
		b = protowire.AppendString(b, {{.FromValue}})
		{{- else if (eq .Name "Message") -}}
		var pos int
		var err error
		b, pos = appendSpeculativeLength(b)
		b, err = o.marshalMessage(b, v.Message())
		if err != nil {
			return b, err
		}
		b = finishSpeculativeLength(b, pos)
		{{- else if (eq .Name "Group") -}}
		var err error
		b, err = o.marshalMessage(b, v.Message())
		if err != nil {
			return b, err
		}
		b = protowire.AppendVarint(b, protowire.EncodeTag(fd.Number(), protowire.EndGroupType))
		{{- else -}}
		b = protowire.Append{{.WireType}}(b, {{.FromValue}})
		{{- end}}
	{{- end}}
	default:
		return b, errors.New("invalid kind %v", fd.Kind())
	}
	return b, nil
}
`))
410
411func generateProtoSize() string {
412	return mustExecute(protoSizeTemplate, ProtoKinds)
413}
414
// protoSizeTemplate renders, from a []ProtoKind, the generated
// MarshalOptions.sizeSingular method, which returns the encoded size of one
// value of the given kind:
//   - Message: protowire.SizeBytes of the nested message size;
//   - Fixed32/Fixed64 wire types: the constant protowire.SizeFixed32/64();
//   - other Bytes wire types: protowire.SizeBytes(len(FromValue));
//   - Group: protowire.SizeGroup(num, nested message size);
//   - everything else: protowire.SizeVarint(FromValue).
//
// An unrecognized kind yields 0.
var protoSizeTemplate = template.Must(template.New("").Parse(`
func (o MarshalOptions) sizeSingular(num protowire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
	switch kind {
	{{- range .}}
	case {{.Expr}}:
		{{if (eq .Name "Message") -}}
		return protowire.SizeBytes(o.size(v.Message()))
		{{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
		return protowire.Size{{.WireType}}()
		{{- else if (eq .WireType "Bytes") -}}
		return protowire.Size{{.WireType}}(len({{.FromValue}}))
		{{- else if (eq .WireType "Group") -}}
		return protowire.Size{{.WireType}}(num, o.size(v.Message()))
		{{- else -}}
		return protowire.Size{{.WireType}}({{.FromValue}})
		{{- end}}
	{{- end}}
	default:
		return 0
	}
}
`))
437