1// Copyright 2014 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// This file is a simple protocol buffer encoder and decoder.
16// The format is described at
17// https://developers.google.com/protocol-buffers/docs/encoding
18//
19// A protocol message must implement the message interface:
20//   decoder() []decoder
21//   encode(*buffer)
22//
// The decoder method returns a slice indexed by field number that gives the
// function to decode that field.
25// The encode method encodes its receiver into the given buffer.
26//
27// The two methods are simple enough to be implemented by hand rather than
28// by using a protocol compiler.
29//
30// See profile.go for examples of messages implementing this interface.
31//
32// There is no support for groups, message sets, or "has" bits.
33
34package profile
35
import (
	"errors"
	"strconv"
)
37
// buffer is shared state for both directions: encoders append their output
// to data, and decodeField unpacks the current field into field/typ/u64/data.
type buffer struct {
	field int      // field number (tag) of the current field
	typ   int      // proto wire type code for field
	u64   uint64   // decoded value of a varint, fixed64 or fixed32 field
	data  []byte   // encoder output; when decoding, the current length-delimited payload
	tmp   [16]byte // scratch used to move a length header in front of its payload
}
45
// decoder decodes a single field, already unpacked into the buffer by
// decodeField, on behalf of the given message. decodeMessage looks decoders
// up by field number in the slice returned by message.decoder.
type decoder func(*buffer, message) error
47
// message is implemented by every protocol message handled by this file.
type message interface {
	// decoder returns a slice, indexed by field number, of the functions
	// that decode each field. decodeMessage skips nil entries and field
	// numbers beyond the end of the slice.
	decoder() []decoder
	// encode appends the wire encoding of the receiver to the buffer's data.
	encode(*buffer)
}
52
53func marshal(m message) []byte {
54	var b buffer
55	m.encode(&b)
56	return b.data
57}
58
59func encodeVarint(b *buffer, x uint64) {
60	for x >= 128 {
61		b.data = append(b.data, byte(x)|0x80)
62		x >>= 7
63	}
64	b.data = append(b.data, byte(x))
65}
66
67func encodeLength(b *buffer, tag int, len int) {
68	encodeVarint(b, uint64(tag)<<3|2)
69	encodeVarint(b, uint64(len))
70}
71
72func encodeUint64(b *buffer, tag int, x uint64) {
73	// append varint to b.data
74	encodeVarint(b, uint64(tag)<<3|0)
75	encodeVarint(b, x)
76}
77
78func encodeUint64s(b *buffer, tag int, x []uint64) {
79	if len(x) > 2 {
80		// Use packed encoding
81		n1 := len(b.data)
82		for _, u := range x {
83			encodeVarint(b, u)
84		}
85		n2 := len(b.data)
86		encodeLength(b, tag, n2-n1)
87		n3 := len(b.data)
88		copy(b.tmp[:], b.data[n2:n3])
89		copy(b.data[n1+(n3-n2):], b.data[n1:n2])
90		copy(b.data[n1:], b.tmp[:n3-n2])
91		return
92	}
93	for _, u := range x {
94		encodeUint64(b, tag, u)
95	}
96}
97
98func encodeUint64Opt(b *buffer, tag int, x uint64) {
99	if x == 0 {
100		return
101	}
102	encodeUint64(b, tag, x)
103}
104
105func encodeInt64(b *buffer, tag int, x int64) {
106	u := uint64(x)
107	encodeUint64(b, tag, u)
108}
109
110func encodeInt64s(b *buffer, tag int, x []int64) {
111	if len(x) > 2 {
112		// Use packed encoding
113		n1 := len(b.data)
114		for _, u := range x {
115			encodeVarint(b, uint64(u))
116		}
117		n2 := len(b.data)
118		encodeLength(b, tag, n2-n1)
119		n3 := len(b.data)
120		copy(b.tmp[:], b.data[n2:n3])
121		copy(b.data[n1+(n3-n2):], b.data[n1:n2])
122		copy(b.data[n1:], b.tmp[:n3-n2])
123		return
124	}
125	for _, u := range x {
126		encodeInt64(b, tag, u)
127	}
128}
129
130func encodeInt64Opt(b *buffer, tag int, x int64) {
131	if x == 0 {
132		return
133	}
134	encodeInt64(b, tag, x)
135}
136
137func encodeString(b *buffer, tag int, x string) {
138	encodeLength(b, tag, len(x))
139	b.data = append(b.data, x...)
140}
141
142func encodeStrings(b *buffer, tag int, x []string) {
143	for _, s := range x {
144		encodeString(b, tag, s)
145	}
146}
147
148func encodeStringOpt(b *buffer, tag int, x string) {
149	if x == "" {
150		return
151	}
152	encodeString(b, tag, x)
153}
154
155func encodeBool(b *buffer, tag int, x bool) {
156	if x {
157		encodeUint64(b, tag, 1)
158	} else {
159		encodeUint64(b, tag, 0)
160	}
161}
162
163func encodeBoolOpt(b *buffer, tag int, x bool) {
164	if x == false {
165		return
166	}
167	encodeBool(b, tag, x)
168}
169
170func encodeMessage(b *buffer, tag int, m message) {
171	n1 := len(b.data)
172	m.encode(b)
173	n2 := len(b.data)
174	encodeLength(b, tag, n2-n1)
175	n3 := len(b.data)
176	copy(b.tmp[:], b.data[n2:n3])
177	copy(b.data[n1+(n3-n2):], b.data[n1:n2])
178	copy(b.data[n1:], b.tmp[:n3-n2])
179}
180
181func unmarshal(data []byte, m message) (err error) {
182	b := buffer{data: data, typ: 2}
183	return decodeMessage(&b, m)
184}
185
// le64 decodes the first eight bytes of p as a little-endian uint64. Any
// bytes after the eighth are ignored.
func le64(p []byte) uint64 {
	lo := uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
	hi := uint32(p[4]) | uint32(p[5])<<8 | uint32(p[6])<<16 | uint32(p[7])<<24
	return uint64(hi)<<32 | uint64(lo)
}
189
// le32 decodes the first four bytes of p as a little-endian uint32.
func le32(p []byte) uint32 {
	b0, b1, b2, b3 := uint32(p[0]), uint32(p[1]), uint32(p[2]), uint32(p[3])
	return b3<<24 | b2<<16 | b1<<8 | b0
}
193
// decodeVarint decodes a base-128 varint from the front of data. It returns
// the value and the bytes that follow it. Data that ends mid-varint, runs
// longer than ten bytes, or encodes a value that does not fit in 64 bits is
// reported as "bad varint".
func decodeVarint(data []byte) (uint64, []byte, error) {
	var u uint64
	for i := 0; i < len(data) && i < 10; i++ {
		c := data[i]
		// The tenth byte can only contribute bit 63. Anything larger would
		// overflow uint64 silently, and a continuation bit would make the
		// varint longer than ten bytes.
		if i == 9 && c > 1 {
			break
		}
		u |= uint64(c&0x7F) << uint(7*i)
		if c&0x80 == 0 {
			return u, data[i+1:], nil
		}
	}
	return 0, nil, errors.New("bad varint")
}
206
207func decodeField(b *buffer, data []byte) ([]byte, error) {
208	x, data, err := decodeVarint(data)
209	if err != nil {
210		return nil, err
211	}
212	b.field = int(x >> 3)
213	b.typ = int(x & 7)
214	b.data = nil
215	b.u64 = 0
216	switch b.typ {
217	case 0:
218		b.u64, data, err = decodeVarint(data)
219		if err != nil {
220			return nil, err
221		}
222	case 1:
223		if len(data) < 8 {
224			return nil, errors.New("not enough data")
225		}
226		b.u64 = le64(data[:8])
227		data = data[8:]
228	case 2:
229		var n uint64
230		n, data, err = decodeVarint(data)
231		if err != nil {
232			return nil, err
233		}
234		if n > uint64(len(data)) {
235			return nil, errors.New("too much data")
236		}
237		b.data = data[:n]
238		data = data[n:]
239	case 5:
240		if len(data) < 4 {
241			return nil, errors.New("not enough data")
242		}
243		b.u64 = uint64(le32(data[:4]))
244		data = data[4:]
245	default:
246		return nil, errors.New("unknown wire type: " + string(b.typ))
247	}
248
249	return data, nil
250}
251
252func checkType(b *buffer, typ int) error {
253	if b.typ != typ {
254		return errors.New("type mismatch")
255	}
256	return nil
257}
258
259func decodeMessage(b *buffer, m message) error {
260	if err := checkType(b, 2); err != nil {
261		return err
262	}
263	dec := m.decoder()
264	data := b.data
265	for len(data) > 0 {
266		// pull varint field# + type
267		var err error
268		data, err = decodeField(b, data)
269		if err != nil {
270			return err
271		}
272		if b.field >= len(dec) || dec[b.field] == nil {
273			continue
274		}
275		if err := dec[b.field](b, m); err != nil {
276			return err
277		}
278	}
279	return nil
280}
281
282func decodeInt64(b *buffer, x *int64) error {
283	if err := checkType(b, 0); err != nil {
284		return err
285	}
286	*x = int64(b.u64)
287	return nil
288}
289
290func decodeInt64s(b *buffer, x *[]int64) error {
291	if b.typ == 2 {
292		// Packed encoding
293		data := b.data
294		tmp := make([]int64, 0, len(data)) // Maximally sized
295		for len(data) > 0 {
296			var u uint64
297			var err error
298
299			if u, data, err = decodeVarint(data); err != nil {
300				return err
301			}
302			tmp = append(tmp, int64(u))
303		}
304		*x = append(*x, tmp...)
305		return nil
306	}
307	var i int64
308	if err := decodeInt64(b, &i); err != nil {
309		return err
310	}
311	*x = append(*x, i)
312	return nil
313}
314
315func decodeUint64(b *buffer, x *uint64) error {
316	if err := checkType(b, 0); err != nil {
317		return err
318	}
319	*x = b.u64
320	return nil
321}
322
323func decodeUint64s(b *buffer, x *[]uint64) error {
324	if b.typ == 2 {
325		data := b.data
326		// Packed encoding
327		tmp := make([]uint64, 0, len(data)) // Maximally sized
328		for len(data) > 0 {
329			var u uint64
330			var err error
331
332			if u, data, err = decodeVarint(data); err != nil {
333				return err
334			}
335			tmp = append(tmp, u)
336		}
337		*x = append(*x, tmp...)
338		return nil
339	}
340	var u uint64
341	if err := decodeUint64(b, &u); err != nil {
342		return err
343	}
344	*x = append(*x, u)
345	return nil
346}
347
348func decodeString(b *buffer, x *string) error {
349	if err := checkType(b, 2); err != nil {
350		return err
351	}
352	*x = string(b.data)
353	return nil
354}
355
356func decodeStrings(b *buffer, x *[]string) error {
357	var s string
358	if err := decodeString(b, &s); err != nil {
359		return err
360	}
361	*x = append(*x, s)
362	return nil
363}
364
365func decodeBool(b *buffer, x *bool) error {
366	if err := checkType(b, 0); err != nil {
367		return err
368	}
369	if int64(b.u64) == 0 {
370		*x = false
371	} else {
372		*x = true
373	}
374	return nil
375}
376