1// Go support for Protocol Buffers - Google's data interchange format
2//
3// Copyright 2010 The Go Authors.  All rights reserved.
4// https://github.com/golang/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//     * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32package proto
33
34/*
35 * Types and routines for supporting protocol buffer extensions.
36 */
37
38import (
39	"errors"
40	"fmt"
41	"reflect"
42	"strconv"
43	"sync"
44)
45
var (
	// ErrMissingExtension is returned by GetExtension when the requested
	// extension is absent from the message and declares no default value.
	// GetExtensions treats it as "absent" and reports a nil element instead.
	ErrMissingExtension = errors.New("proto: missing extension")
)
48
// ExtensionRange describes one contiguous span of field numbers that a
// message reserves for extensions. Both bounds are inclusive, so
// ExtensionRange{Start: 100, End: 199} admits fields 100 through 199.
// Code generated by the protocol compiler uses this type.
type ExtensionRange struct {
	Start int32 // first extension field number (inclusive)
	End   int32 // last extension field number (inclusive)
}
54
55// extendableProto is an interface implemented by any protocol buffer that may be extended.
56type extendableProto interface {
57	Message
58	ExtensionRangeArray() []ExtensionRange
59	ExtensionMap() map[int32]Extension
60}
61
62var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
63
// ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler.
//
// A descriptor identifies an extension by the message type it extends
// (ExtendedType) together with its field number (Field). ExtensionType fixes
// the Go type that SetExtension accepts and that decoding produces; slice
// types other than []byte are treated as repeated extensions.
type ExtensionDesc struct {
	ExtendedType  Message     // nil pointer to the type that is being extended
	ExtensionType interface{} // nil pointer (or nil slice) of the extension's Go type
	Field         int32       // field number; must lie in one of the extended type's ExtensionRanges
	Name          string      // fully-qualified name of extension, for text formatting
	Tag           string      // protobuf tag style; passed to Properties.Init
}
73
74func (ed *ExtensionDesc) repeated() bool {
75	t := reflect.TypeOf(ed.ExtensionType)
76	return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
77}
78
// Extension represents an extension in a message.
type Extension struct {
	// When an extension is stored in a message using SetExtension,
	// only desc and value are set. When the message is marshaled,
	// enc will be set to the encoded form of the extension.
	//
	// When a message is unmarshaled and contains extensions, each
	// extension will have only enc set. When such an extension is
	// accessed using GetExtension (or GetExtensions), desc and value
	// will be set and enc is dropped.
	desc  *ExtensionDesc // descriptor; nil while only the encoded form is known
	value interface{}    // decoded value, with the dynamic type of desc.ExtensionType
	enc   []byte         // wire-format bytes, including the leading tag varint(s)
}
93
94// SetRawExtension is for testing only.
95func SetRawExtension(base extendableProto, id int32, b []byte) {
96	base.ExtensionMap()[id] = Extension{enc: b}
97}
98
99// isExtensionField returns true iff the given field number is in an extension range.
100func isExtensionField(pb extendableProto, field int32) bool {
101	for _, er := range pb.ExtensionRangeArray() {
102		if er.Start <= field && field <= er.End {
103			return true
104		}
105	}
106	return false
107}
108
109// checkExtensionTypes checks that the given extension is valid for pb.
110func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
111	// Check the extended type.
112	if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
113		return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
114	}
115	// Check the range.
116	if !isExtensionField(pb, extension.Field) {
117		return errors.New("proto: bad extension number; not in declared ranges")
118	}
119	return nil
120}
121
// extPropKey is sufficient to uniquely identify an extension: the Go type of
// the message it extends paired with its field number. It keys the extProp
// cache consulted by extensionProperties.
type extPropKey struct {
	base  reflect.Type // reflect.TypeOf(ExtensionDesc.ExtendedType)
	field int32        // ExtensionDesc.Field
}
127
// extProp caches the Properties built by extensionProperties, one per
// extension (see extPropKey). The embedded RWMutex guards m: readers take
// RLock, and insertions happen under Lock.
var extProp = struct {
	sync.RWMutex
	m map[extPropKey]*Properties
}{
	m: make(map[extPropKey]*Properties),
}
134
135func extensionProperties(ed *ExtensionDesc) *Properties {
136	key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
137
138	extProp.RLock()
139	if prop, ok := extProp.m[key]; ok {
140		extProp.RUnlock()
141		return prop
142	}
143	extProp.RUnlock()
144
145	extProp.Lock()
146	defer extProp.Unlock()
147	// Check again.
148	if prop, ok := extProp.m[key]; ok {
149		return prop
150	}
151
152	prop := new(Properties)
153	prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
154	extProp.m[key] = prop
155	return prop
156}
157
158// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
159func encodeExtensionMap(m map[int32]Extension) error {
160	for k, e := range m {
161		if e.value == nil || e.desc == nil {
162			// Extension is only in its encoded form.
163			continue
164		}
165
166		// We don't skip extensions that have an encoded form set,
167		// because the extension value may have been mutated after
168		// the last time this function was called.
169
170		et := reflect.TypeOf(e.desc.ExtensionType)
171		props := extensionProperties(e.desc)
172
173		p := NewBuffer(nil)
174		// If e.value has type T, the encoder expects a *struct{ X T }.
175		// Pass a *T with a zero field and hope it all works out.
176		x := reflect.New(et)
177		x.Elem().Set(reflect.ValueOf(e.value))
178		if err := props.enc(p, props, toStructPointer(x)); err != nil {
179			return err
180		}
181		e.enc = p.buf
182		m[k] = e
183	}
184	return nil
185}
186
187func sizeExtensionMap(m map[int32]Extension) (n int) {
188	for _, e := range m {
189		if e.value == nil || e.desc == nil {
190			// Extension is only in its encoded form.
191			n += len(e.enc)
192			continue
193		}
194
195		// We don't skip extensions that have an encoded form set,
196		// because the extension value may have been mutated after
197		// the last time this function was called.
198
199		et := reflect.TypeOf(e.desc.ExtensionType)
200		props := extensionProperties(e.desc)
201
202		// If e.value has type T, the encoder expects a *struct{ X T }.
203		// Pass a *T with a zero field and hope it all works out.
204		x := reflect.New(et)
205		x.Elem().Set(reflect.ValueOf(e.value))
206		n += props.size(props, toStructPointer(x))
207	}
208	return
209}
210
211// HasExtension returns whether the given extension is present in pb.
212func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
213	// TODO: Check types, field numbers, etc.?
214	_, ok := pb.ExtensionMap()[extension.Field]
215	return ok
216}
217
218// ClearExtension removes the given extension from pb.
219func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
220	// TODO: Check types, field numbers, etc.?
221	delete(pb.ExtensionMap(), extension.Field)
222}
223
224// GetExtension parses and returns the given extension of pb.
225// If the extension is not present and has no default value it returns ErrMissingExtension.
226func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) {
227	if err := checkExtensionTypes(pb, extension); err != nil {
228		return nil, err
229	}
230
231	emap := pb.ExtensionMap()
232	e, ok := emap[extension.Field]
233	if !ok {
234		// defaultExtensionValue returns the default value or
235		// ErrMissingExtension if there is no default.
236		return defaultExtensionValue(extension)
237	}
238
239	if e.value != nil {
240		// Already decoded. Check the descriptor, though.
241		if e.desc != extension {
242			// This shouldn't happen. If it does, it means that
243			// GetExtension was called twice with two different
244			// descriptors with the same field number.
245			return nil, errors.New("proto: descriptor conflict")
246		}
247		return e.value, nil
248	}
249
250	v, err := decodeExtension(e.enc, extension)
251	if err != nil {
252		return nil, err
253	}
254
255	// Remember the decoded version and drop the encoded version.
256	// That way it is safe to mutate what we return.
257	e.value = v
258	e.desc = extension
259	e.enc = nil
260	emap[extension.Field] = e
261	return e.value, nil
262}
263
264// defaultExtensionValue returns the default value for extension.
265// If no default for an extension is defined ErrMissingExtension is returned.
266func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
267	t := reflect.TypeOf(extension.ExtensionType)
268	props := extensionProperties(extension)
269
270	sf, _, err := fieldDefault(t, props)
271	if err != nil {
272		return nil, err
273	}
274
275	if sf == nil || sf.value == nil {
276		// There is no default value.
277		return nil, ErrMissingExtension
278	}
279
280	if t.Kind() != reflect.Ptr {
281		// We do not need to return a Ptr, we can directly return sf.value.
282		return sf.value, nil
283	}
284
285	// We need to return an interface{} that is a pointer to sf.value.
286	value := reflect.New(t).Elem()
287	value.Set(reflect.New(value.Type().Elem()))
288	if sf.kind == reflect.Int32 {
289		// We may have an int32 or an enum, but the underlying data is int32.
290		// Since we can't set an int32 into a non int32 reflect.value directly
291		// set it as a int32.
292		value.Elem().SetInt(int64(sf.value.(int32)))
293	} else {
294		value.Elem().Set(reflect.ValueOf(sf.value))
295	}
296	return value.Interface(), nil
297}
298
// decodeExtension decodes an extension encoded in b.
//
// b holds one or more wire-format entries for the extension, each starting
// with its tag varint (GetExtension passes Extension.enc here). Entries are
// decoded one after another into the same destination until b is exhausted.
// The result has the dynamic type of extension.ExtensionType. Any error from
// reading the tag or from props.dec is returned unchanged.
// NOTE(review): repeated extensions rely on props.dec appending to the slice
// across iterations rather than overwriting it — confirm against the decoder.
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
	o := NewBuffer(b)

	t := reflect.TypeOf(extension.ExtensionType)

	props := extensionProperties(extension)

	// t is a pointer to a struct, pointer to basic type or a slice.
	// Allocate a "field" to store the pointer/slice itself; the
	// pointer/slice will be stored here. We pass
	// the address of this field to props.dec.
	// This passes a zero field and a *t and lets props.dec
	// interpret it as a *struct{ x t }.
	value := reflect.New(t).Elem()

	// The body runs at least once, so an empty b goes straight to
	// DecodeVarint (presumably an error) rather than yielding a zero value.
	for {
		// Discard wire type and field number varint. It isn't needed.
		if _, err := o.DecodeVarint(); err != nil {
			return nil, err
		}

		// Decode this entry's payload into value.
		if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
			return nil, err
		}

		// Stop once every byte of b has been consumed.
		if o.index >= len(o.buf) {
			break
		}
	}
	return value.Interface(), nil
}
331
332// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
333// The returned slice has the same length as es; missing extensions will appear as nil elements.
334func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
335	epb, ok := pb.(extendableProto)
336	if !ok {
337		err = errors.New("proto: not an extendable proto")
338		return
339	}
340	extensions = make([]interface{}, len(es))
341	for i, e := range es {
342		extensions[i], err = GetExtension(epb, e)
343		if err == ErrMissingExtension {
344			err = nil
345		}
346		if err != nil {
347			return
348		}
349	}
350	return
351}
352
353// SetExtension sets the specified extension of pb to the specified value.
354func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error {
355	if err := checkExtensionTypes(pb, extension); err != nil {
356		return err
357	}
358	typ := reflect.TypeOf(extension.ExtensionType)
359	if typ != reflect.TypeOf(value) {
360		return errors.New("proto: bad extension value type")
361	}
362	// nil extension values need to be caught early, because the
363	// encoder can't distinguish an ErrNil due to a nil extension
364	// from an ErrNil due to a missing field. Extensions are
365	// always optional, so the encoder would just swallow the error
366	// and drop all the extensions from the encoded message.
367	if reflect.ValueOf(value).IsNil() {
368		return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
369	}
370
371	pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
372	return nil
373}
374
375// A global registry of extensions.
376// The generated code will register the generated descriptors by calling RegisterExtension.
377
378var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
379
380// RegisterExtension is called from the generated code.
381func RegisterExtension(desc *ExtensionDesc) {
382	st := reflect.TypeOf(desc.ExtendedType).Elem()
383	m := extensionMaps[st]
384	if m == nil {
385		m = make(map[int32]*ExtensionDesc)
386		extensionMaps[st] = m
387	}
388	if _, ok := m[desc.Field]; ok {
389		panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
390	}
391	m[desc.Field] = desc
392}
393
394// RegisteredExtensions returns a map of the registered extensions of a
395// protocol buffer struct, indexed by the extension number.
396// The argument pb should be a nil pointer to the struct type.
397func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
398	return extensionMaps[reflect.TypeOf(pb).Elem()]
399}
400