1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"errors"
9	"fmt"
10
11	"google.golang.org/protobuf/encoding/prototext"
12	"google.golang.org/protobuf/encoding/protowire"
13	"google.golang.org/protobuf/runtime/protoimpl"
14)
15
// Protocol buffer wire-type constants, listed in ascending numeric order.
const (
	WireVarint     = 0
	WireFixed64    = 1
	WireBytes      = 2
	WireStartGroup = 3
	WireEndGroup   = 4
	WireFixed32    = 5
)
24
25// EncodeVarint returns the varint encoded bytes of v.
26func EncodeVarint(v uint64) []byte {
27	return protowire.AppendVarint(nil, v)
28}
29
30// SizeVarint returns the length of the varint encoded bytes of v.
31// This is equal to len(EncodeVarint(v)).
32func SizeVarint(v uint64) int {
33	return protowire.SizeVarint(v)
34}
35
36// DecodeVarint parses a varint encoded integer from b,
37// returning the integer value and the length of the varint.
38// It returns (0, 0) if there is a parse error.
39func DecodeVarint(b []byte) (uint64, int) {
40	v, n := protowire.ConsumeVarint(b)
41	if n < 0 {
42		return 0, 0
43	}
44	return v, n
45}
46
// Buffer is a buffer for encoding and decoding the protobuf wire format.
// It may be reused between invocations to reduce memory usage.
//
// Encode* and Marshal methods append to buf; Decode* methods read from
// buf[idx:] and advance idx past whatever they consume.
type Buffer struct {
	buf           []byte // encoded bytes, both already-read and unread
	idx           int    // read offset into buf; buf[:idx] has been consumed
	deterministic bool   // passed to marshalAppend by Marshal and EncodeMessage
}
54
55// NewBuffer allocates a new Buffer initialized with buf,
56// where the contents of buf are considered the unread portion of the buffer.
57func NewBuffer(buf []byte) *Buffer {
58	return &Buffer{buf: buf}
59}
60
61// SetDeterministic specifies whether to use deterministic serialization.
62//
63// Deterministic serialization guarantees that for a given binary, equal
64// messages will always be serialized to the same bytes. This implies:
65//
66//   - Repeated serialization of a message will return the same bytes.
67//   - Different processes of the same binary (which may be executing on
68//     different machines) will serialize equal messages to the same bytes.
69//
70// Note that the deterministic serialization is NOT canonical across
71// languages. It is not guaranteed to remain stable over time. It is unstable
72// across different builds with schema changes due to unknown fields.
73// Users who need canonical serialization (e.g., persistent storage in a
74// canonical form, fingerprinting, etc.) should define their own
75// canonicalization specification and implement their own serializer rather
76// than relying on this API.
77//
78// If deterministic serialization is requested, map entries will be sorted
79// by keys in lexographical order. This is an implementation detail and
80// subject to change.
81func (b *Buffer) SetDeterministic(deterministic bool) {
82	b.deterministic = deterministic
83}
84
85// SetBuf sets buf as the internal buffer,
86// where the contents of buf are considered the unread portion of the buffer.
87func (b *Buffer) SetBuf(buf []byte) {
88	b.buf = buf
89	b.idx = 0
90}
91
92// Reset clears the internal buffer of all written and unread data.
93func (b *Buffer) Reset() {
94	b.buf = b.buf[:0]
95	b.idx = 0
96}
97
98// Bytes returns the internal buffer.
99func (b *Buffer) Bytes() []byte {
100	return b.buf
101}
102
103// Unread returns the unread portion of the buffer.
104func (b *Buffer) Unread() []byte {
105	return b.buf[b.idx:]
106}
107
108// Marshal appends the wire-format encoding of m to the buffer.
109func (b *Buffer) Marshal(m Message) error {
110	var err error
111	b.buf, err = marshalAppend(b.buf, m, b.deterministic)
112	return err
113}
114
115// Unmarshal parses the wire-format message in the buffer and
116// places the decoded results in m.
117// It does not reset m before unmarshaling.
118func (b *Buffer) Unmarshal(m Message) error {
119	err := UnmarshalMerge(b.Unread(), m)
120	b.idx = len(b.buf)
121	return err
122}
123
124type unknownFields struct{ XXX_unrecognized protoimpl.UnknownFields }
125
126func (m *unknownFields) String() string { panic("not implemented") }
127func (m *unknownFields) Reset()         { panic("not implemented") }
128func (m *unknownFields) ProtoMessage()  { panic("not implemented") }
129
130// DebugPrint dumps the encoded bytes of b with a header and footer including s
131// to stdout. This is only intended for debugging.
132func (*Buffer) DebugPrint(s string, b []byte) {
133	m := MessageReflect(new(unknownFields))
134	m.SetUnknown(b)
135	b, _ = prototext.MarshalOptions{AllowPartial: true, Indent: "\t"}.Marshal(m.Interface())
136	fmt.Printf("==== %s ====\n%s==== %s ====\n", s, b, s)
137}
138
139// EncodeVarint appends an unsigned varint encoding to the buffer.
140func (b *Buffer) EncodeVarint(v uint64) error {
141	b.buf = protowire.AppendVarint(b.buf, v)
142	return nil
143}
144
145// EncodeZigzag32 appends a 32-bit zig-zag varint encoding to the buffer.
146func (b *Buffer) EncodeZigzag32(v uint64) error {
147	return b.EncodeVarint(uint64((uint32(v) << 1) ^ uint32((int32(v) >> 31))))
148}
149
150// EncodeZigzag64 appends a 64-bit zig-zag varint encoding to the buffer.
151func (b *Buffer) EncodeZigzag64(v uint64) error {
152	return b.EncodeVarint(uint64((uint64(v) << 1) ^ uint64((int64(v) >> 63))))
153}
154
155// EncodeFixed32 appends a 32-bit little-endian integer to the buffer.
156func (b *Buffer) EncodeFixed32(v uint64) error {
157	b.buf = protowire.AppendFixed32(b.buf, uint32(v))
158	return nil
159}
160
161// EncodeFixed64 appends a 64-bit little-endian integer to the buffer.
162func (b *Buffer) EncodeFixed64(v uint64) error {
163	b.buf = protowire.AppendFixed64(b.buf, uint64(v))
164	return nil
165}
166
167// EncodeRawBytes appends a length-prefixed raw bytes to the buffer.
168func (b *Buffer) EncodeRawBytes(v []byte) error {
169	b.buf = protowire.AppendBytes(b.buf, v)
170	return nil
171}
172
173// EncodeStringBytes appends a length-prefixed raw bytes to the buffer.
174// It does not validate whether v contains valid UTF-8.
175func (b *Buffer) EncodeStringBytes(v string) error {
176	b.buf = protowire.AppendString(b.buf, v)
177	return nil
178}
179
180// EncodeMessage appends a length-prefixed encoded message to the buffer.
181func (b *Buffer) EncodeMessage(m Message) error {
182	var err error
183	b.buf = protowire.AppendVarint(b.buf, uint64(Size(m)))
184	b.buf, err = marshalAppend(b.buf, m, b.deterministic)
185	return err
186}
187
188// DecodeVarint consumes an encoded unsigned varint from the buffer.
189func (b *Buffer) DecodeVarint() (uint64, error) {
190	v, n := protowire.ConsumeVarint(b.buf[b.idx:])
191	if n < 0 {
192		return 0, protowire.ParseError(n)
193	}
194	b.idx += n
195	return uint64(v), nil
196}
197
198// DecodeZigzag32 consumes an encoded 32-bit zig-zag varint from the buffer.
199func (b *Buffer) DecodeZigzag32() (uint64, error) {
200	v, err := b.DecodeVarint()
201	if err != nil {
202		return 0, err
203	}
204	return uint64((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31)), nil
205}
206
207// DecodeZigzag64 consumes an encoded 64-bit zig-zag varint from the buffer.
208func (b *Buffer) DecodeZigzag64() (uint64, error) {
209	v, err := b.DecodeVarint()
210	if err != nil {
211		return 0, err
212	}
213	return uint64((uint64(v) >> 1) ^ uint64((int64(v&1)<<63)>>63)), nil
214}
215
216// DecodeFixed32 consumes a 32-bit little-endian integer from the buffer.
217func (b *Buffer) DecodeFixed32() (uint64, error) {
218	v, n := protowire.ConsumeFixed32(b.buf[b.idx:])
219	if n < 0 {
220		return 0, protowire.ParseError(n)
221	}
222	b.idx += n
223	return uint64(v), nil
224}
225
226// DecodeFixed64 consumes a 64-bit little-endian integer from the buffer.
227func (b *Buffer) DecodeFixed64() (uint64, error) {
228	v, n := protowire.ConsumeFixed64(b.buf[b.idx:])
229	if n < 0 {
230		return 0, protowire.ParseError(n)
231	}
232	b.idx += n
233	return uint64(v), nil
234}
235
236// DecodeRawBytes consumes a length-prefixed raw bytes from the buffer.
237// If alloc is specified, it returns a copy the raw bytes
238// rather than a sub-slice of the buffer.
239func (b *Buffer) DecodeRawBytes(alloc bool) ([]byte, error) {
240	v, n := protowire.ConsumeBytes(b.buf[b.idx:])
241	if n < 0 {
242		return nil, protowire.ParseError(n)
243	}
244	b.idx += n
245	if alloc {
246		v = append([]byte(nil), v...)
247	}
248	return v, nil
249}
250
251// DecodeStringBytes consumes a length-prefixed raw bytes from the buffer.
252// It does not validate whether the raw bytes contain valid UTF-8.
253func (b *Buffer) DecodeStringBytes() (string, error) {
254	v, n := protowire.ConsumeString(b.buf[b.idx:])
255	if n < 0 {
256		return "", protowire.ParseError(n)
257	}
258	b.idx += n
259	return v, nil
260}
261
262// DecodeMessage consumes a length-prefixed message from the buffer.
263// It does not reset m before unmarshaling.
264func (b *Buffer) DecodeMessage(m Message) error {
265	v, err := b.DecodeRawBytes(false)
266	if err != nil {
267		return err
268	}
269	return UnmarshalMerge(v, m)
270}
271
272// DecodeGroup consumes a message group from the buffer.
273// It assumes that the start group marker has already been consumed and
274// consumes all bytes until (and including the end group marker).
275// It does not reset m before unmarshaling.
276func (b *Buffer) DecodeGroup(m Message) error {
277	v, n, err := consumeGroup(b.buf[b.idx:])
278	if err != nil {
279		return err
280	}
281	b.idx += n
282	return UnmarshalMerge(v, m)
283}
284
285// consumeGroup parses b until it finds an end group marker, returning
286// the raw bytes of the message (excluding the end group marker) and the
287// the total length of the message (including the end group marker).
288func consumeGroup(b []byte) ([]byte, int, error) {
289	b0 := b
290	depth := 1 // assume this follows a start group marker
291	for {
292		_, wtyp, tagLen := protowire.ConsumeTag(b)
293		if tagLen < 0 {
294			return nil, 0, protowire.ParseError(tagLen)
295		}
296		b = b[tagLen:]
297
298		var valLen int
299		switch wtyp {
300		case protowire.VarintType:
301			_, valLen = protowire.ConsumeVarint(b)
302		case protowire.Fixed32Type:
303			_, valLen = protowire.ConsumeFixed32(b)
304		case protowire.Fixed64Type:
305			_, valLen = protowire.ConsumeFixed64(b)
306		case protowire.BytesType:
307			_, valLen = protowire.ConsumeBytes(b)
308		case protowire.StartGroupType:
309			depth++
310		case protowire.EndGroupType:
311			depth--
312		default:
313			return nil, 0, errors.New("proto: cannot parse reserved wire type")
314		}
315		if valLen < 0 {
316			return nil, 0, protowire.ParseError(valLen)
317		}
318		b = b[valLen:]
319
320		if depth == 0 {
321			return b0[:len(b0)-len(b)-tagLen], len(b0) - len(b), nil
322		}
323	}
324}
325