1// Go support for Protocol Buffers - Google's data interchange format
2//
3// Copyright 2010 The Go Authors.  All rights reserved.
4// https://github.com/golang/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//     * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32package proto
33
34/*
35 * Support for message sets.
36 */
37
38import (
39	"bytes"
40	"encoding/json"
41	"errors"
42	"fmt"
43	"reflect"
44	"sort"
45)
46
// errNoMessageTypeID occurs when a protocol buffer does not have a message type ID.
// A message type ID is required for storing a protocol buffer in a message set.
// It is returned by (*messageSet).Marshal and (*messageSet).Unmarshal when the
// supplied message does not implement messageTypeIder.
var errNoMessageTypeID = errors.New("proto does not have a message type ID")
50
51// The first two types (_MessageSet_Item and messageSet)
52// model what the protocol compiler produces for the following protocol message:
53//   message MessageSet {
54//     repeated group Item = 1 {
55//       required int32 type_id = 2;
56//       required string message = 3;
57//     };
58//   }
59// That is the MessageSet wire format. We can't use a proto to generate these
60// because that would introduce a circular dependency between it and this package.
61
// _MessageSet_Item is a single entry of a message set: a type ID paired with
// the encoded bytes of the message stored under that ID.
type _MessageSet_Item struct {
	// TypeId is the ID of the stored message's type; (*messageSet).find
	// compares it against the value returned by MessageTypeId.
	TypeId  *int32 `protobuf:"varint,2,req,name=type_id"`
	// Message is the encoded form of the stored message, as produced by
	// Marshal and consumed by Unmarshal.
	Message []byte `protobuf:"bytes,3,req,name=message"`
}
66
// messageSet is the in-memory form of the MessageSet wire format: a list of
// items, each holding one encoded message together with its type ID.
type messageSet struct {
	// Item lists the stored entries; lookups by type ID scan it linearly.
	Item             []*_MessageSet_Item `protobuf:"group,1,rep"`
	// XXX_unrecognized is not touched in this file; presumably the decoder
	// stores unknown fields here — confirm against the decoder.
	XXX_unrecognized []byte
	// TODO: caching?
}
72
// Compile-time assertion that *messageSet satisfies the Message interface.
var _ Message = (*messageSet)(nil)
75
// messageTypeIder is an interface satisfied by a protocol buffer type
// that may be stored in a MessageSet.
type messageTypeIder interface {
	// MessageTypeId returns the ID that is matched against
	// _MessageSet_Item.TypeId when looking up or storing the message.
	MessageTypeId() int32
}
81
82func (ms *messageSet) find(pb Message) *_MessageSet_Item {
83	mti, ok := pb.(messageTypeIder)
84	if !ok {
85		return nil
86	}
87	id := mti.MessageTypeId()
88	for _, item := range ms.Item {
89		if *item.TypeId == id {
90			return item
91		}
92	}
93	return nil
94}
95
96func (ms *messageSet) Has(pb Message) bool {
97	if ms.find(pb) != nil {
98		return true
99	}
100	return false
101}
102
103func (ms *messageSet) Unmarshal(pb Message) error {
104	if item := ms.find(pb); item != nil {
105		return Unmarshal(item.Message, pb)
106	}
107	if _, ok := pb.(messageTypeIder); !ok {
108		return errNoMessageTypeID
109	}
110	return nil // TODO: return error instead?
111}
112
113func (ms *messageSet) Marshal(pb Message) error {
114	msg, err := Marshal(pb)
115	if err != nil {
116		return err
117	}
118	if item := ms.find(pb); item != nil {
119		// reuse existing item
120		item.Message = msg
121		return nil
122	}
123
124	mti, ok := pb.(messageTypeIder)
125	if !ok {
126		return errNoMessageTypeID
127	}
128
129	mtid := mti.MessageTypeId()
130	ms.Item = append(ms.Item, &_MessageSet_Item{
131		TypeId:  &mtid,
132		Message: msg,
133	})
134	return nil
135}
136
137func (ms *messageSet) Reset()         { *ms = messageSet{} }
138func (ms *messageSet) String() string { return CompactTextString(ms) }
139func (*messageSet) ProtoMessage()     {}
140
141// Support for the message_set_wire_format message option.
142
// skipVarint returns the portion of buf that follows its leading varint.
// A varint ends at the first byte whose continuation bit (0x80) is clear.
// If buf is empty, or ends before the varint terminates, skipVarint returns
// nil instead of indexing past the end of the slice.
func skipVarint(buf []byte) []byte {
	for i, b := range buf {
		if b&0x80 == 0 {
			return buf[i+1:]
		}
	}
	// No terminating byte found: nothing follows the (incomplete) varint.
	return nil
}
149
150// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
151// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
152func MarshalMessageSet(exts interface{}) ([]byte, error) {
153	var m map[int32]Extension
154	switch exts := exts.(type) {
155	case *XXX_InternalExtensions:
156		if err := encodeExtensions(exts); err != nil {
157			return nil, err
158		}
159		m, _ = exts.extensionsRead()
160	case map[int32]Extension:
161		if err := encodeExtensionsMap(exts); err != nil {
162			return nil, err
163		}
164		m = exts
165	default:
166		return nil, errors.New("proto: not an extension map")
167	}
168
169	// Sort extension IDs to provide a deterministic encoding.
170	// See also enc_map in encode.go.
171	ids := make([]int, 0, len(m))
172	for id := range m {
173		ids = append(ids, int(id))
174	}
175	sort.Ints(ids)
176
177	ms := &messageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
178	for _, id := range ids {
179		e := m[int32(id)]
180		// Remove the wire type and field number varint, as well as the length varint.
181		msg := skipVarint(skipVarint(e.enc))
182
183		ms.Item = append(ms.Item, &_MessageSet_Item{
184			TypeId:  Int32(int32(id)),
185			Message: msg,
186		})
187	}
188	return Marshal(ms)
189}
190
191// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
192// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
193func UnmarshalMessageSet(buf []byte, exts interface{}) error {
194	var m map[int32]Extension
195	switch exts := exts.(type) {
196	case *XXX_InternalExtensions:
197		m = exts.extensionsWrite()
198	case map[int32]Extension:
199		m = exts
200	default:
201		return errors.New("proto: not an extension map")
202	}
203
204	ms := new(messageSet)
205	if err := Unmarshal(buf, ms); err != nil {
206		return err
207	}
208	for _, item := range ms.Item {
209		id := *item.TypeId
210		msg := item.Message
211
212		// Restore wire type and field number varint, plus length varint.
213		// Be careful to preserve duplicate items.
214		b := EncodeVarint(uint64(id)<<3 | WireBytes)
215		if ext, ok := m[id]; ok {
216			// Existing data; rip off the tag and length varint
217			// so we join the new data correctly.
218			// We can assume that ext.enc is set because we are unmarshaling.
219			o := ext.enc[len(b):]   // skip wire type and field number
220			_, n := DecodeVarint(o) // calculate length of length varint
221			o = o[n:]               // skip length varint
222			msg = append(o, msg...) // join old data and new data
223		}
224		b = append(b, EncodeVarint(uint64(len(msg)))...)
225		b = append(b, msg...)
226
227		m[id] = Extension{enc: b}
228	}
229	return nil
230}
231
232// MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
233// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
234func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
235	var m map[int32]Extension
236	switch exts := exts.(type) {
237	case *XXX_InternalExtensions:
238		m, _ = exts.extensionsRead()
239	case map[int32]Extension:
240		m = exts
241	default:
242		return nil, errors.New("proto: not an extension map")
243	}
244	var b bytes.Buffer
245	b.WriteByte('{')
246
247	// Process the map in key order for deterministic output.
248	ids := make([]int32, 0, len(m))
249	for id := range m {
250		ids = append(ids, id)
251	}
252	sort.Sort(int32Slice(ids)) // int32Slice defined in text.go
253
254	for i, id := range ids {
255		ext := m[id]
256		if i > 0 {
257			b.WriteByte(',')
258		}
259
260		msd, ok := messageSetMap[id]
261		if !ok {
262			// Unknown type; we can't render it, so skip it.
263			continue
264		}
265		fmt.Fprintf(&b, `"[%s]":`, msd.name)
266
267		x := ext.value
268		if x == nil {
269			x = reflect.New(msd.t.Elem()).Interface()
270			if err := Unmarshal(ext.enc, x.(Message)); err != nil {
271				return nil, err
272			}
273		}
274		d, err := json.Marshal(x)
275		if err != nil {
276			return nil, err
277		}
278		b.Write(d)
279	}
280	b.WriteByte('}')
281	return b.Bytes(), nil
282}
283
// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
// Only empty input and the literal "{}" are accepted (both are no-ops);
// anything else returns a "not yet implemented" error. exts is not used.
func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error {
	emptyObject := bytes.Equal(buf, []byte("{}"))
	if len(buf) > 0 && !emptyObject {
		// Decoding real content is fairly tricky, and it's not clear
		// that it is needed.
		return errors.New("TODO: UnmarshalMessageSetJSON not yet implemented")
	}
	// Common case: nothing to decode.
	return nil
}
295
// A global registry of types that can be used in a MessageSet.

// messageSetMap maps a field number to the descriptor recorded for it by
// RegisterMessageSetType. MarshalMessageSetJSON consults it to find the name
// and Go type for each extension ID.
// NOTE(review): no synchronization is visible in this file; presumably all
// registration happens during package initialization — confirm.
var messageSetMap = make(map[int32]messageSetDesc)
299
// messageSetDesc is the registry entry for one message set type, as recorded
// by RegisterMessageSetType.
type messageSetDesc struct {
	t    reflect.Type // pointer to struct
	// name is the identifier rendered as "[name]" by MarshalMessageSetJSON.
	name string
}
304
305// RegisterMessageSetType is called from the generated code.
306func RegisterMessageSetType(m Message, fieldNum int32, name string) {
307	messageSetMap[fieldNum] = messageSetDesc{
308		t:    reflect.TypeOf(m),
309		name: name,
310	}
311}
312