1// Go support for Protocol Buffers - Google's data interchange format
2//
3// Copyright 2010 The Go Authors.  All rights reserved.
4// https://github.com/golang/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//     * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32package proto
33
34/*
35 * Types and routines for supporting protocol buffer extensions.
36 */
37
38import (
39	"errors"
40	"fmt"
41	"reflect"
42	"strconv"
43	"sync"
44)
45
var (
	// ErrMissingExtension is returned by GetExtension when the requested
	// extension is not present in the message and declares no default value.
	ErrMissingExtension = errors.New("proto: missing extension")
)

// ExtensionRange describes one contiguous span of field numbers that a
// message reserves for extensions. It is used in code generated by the
// protocol compiler.
type ExtensionRange struct {
	Start int32 // first field number of the span (inclusive)
	End   int32 // last field number of the span (inclusive)
}
54
55// extendableProto is an interface implemented by any protocol buffer generated by the current
56// proto compiler that may be extended.
57type extendableProto interface {
58	Message
59	ExtensionRangeArray() []ExtensionRange
60	extensionsWrite() map[int32]Extension
61	extensionsRead() (map[int32]Extension, sync.Locker)
62}
63
64// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
65// version of the proto compiler that may be extended.
66type extendableProtoV1 interface {
67	Message
68	ExtensionRangeArray() []ExtensionRange
69	ExtensionMap() map[int32]Extension
70}
71
72// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
73type extensionAdapter struct {
74	extendableProtoV1
75}
76
77func (e extensionAdapter) extensionsWrite() map[int32]Extension {
78	return e.ExtensionMap()
79}
80
81func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
82	return e.ExtensionMap(), notLocker{}
83}
84
// notLocker satisfies sync.Locker but performs no locking. It stands in
// wherever a Locker is required but nothing needs to be synchronized.
type notLocker struct{}

// Lock does nothing.
func (notLocker) Lock() {}

// Unlock does nothing.
func (notLocker) Unlock() {}
90
91// extendable returns the extendableProto interface for the given generated proto message.
92// If the proto message has the old extension format, it returns a wrapper that implements
93// the extendableProto interface.
94func extendable(p interface{}) (extendableProto, bool) {
95	if ep, ok := p.(extendableProto); ok {
96		return ep, ok
97	}
98	if ep, ok := p.(extendableProtoV1); ok {
99		return extensionAdapter{ep}, ok
100	}
101	return nil, false
102}
103
104// XXX_InternalExtensions is an internal representation of proto extensions.
105//
106// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
107// thus gaining the unexported 'extensions' method, which can be called only from the proto package.
108//
109// The methods of XXX_InternalExtensions are not concurrency safe in general,
110// but calls to logically read-only methods such as has and get may be executed concurrently.
111type XXX_InternalExtensions struct {
112	// The struct must be indirect so that if a user inadvertently copies a
113	// generated message and its embedded XXX_InternalExtensions, they
114	// avoid the mayhem of a copied mutex.
115	//
116	// The mutex serializes all logically read-only operations to p.extensionMap.
117	// It is up to the client to ensure that write operations to p.extensionMap are
118	// mutually exclusive with other accesses.
119	p *struct {
120		mu           sync.Mutex
121		extensionMap map[int32]Extension
122	}
123}
124
125// extensionsWrite returns the extension map, creating it on first use.
126func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
127	if e.p == nil {
128		e.p = new(struct {
129			mu           sync.Mutex
130			extensionMap map[int32]Extension
131		})
132		e.p.extensionMap = make(map[int32]Extension)
133	}
134	return e.p.extensionMap
135}
136
137// extensionsRead returns the extensions map for read-only use.  It may be nil.
138// The caller must hold the returned mutex's lock when accessing Elements within the map.
139func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
140	if e.p == nil {
141		return nil, nil
142	}
143	return e.p.extensionMap, &e.p.mu
144}
145
146var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
147var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
148
149// ExtensionDesc represents an extension specification.
150// Used in generated code from the protocol compiler.
151type ExtensionDesc struct {
152	ExtendedType  Message     // nil pointer to the type that is being extended
153	ExtensionType interface{} // nil pointer to the extension type
154	Field         int32       // field number
155	Name          string      // fully-qualified name of extension, for text formatting
156	Tag           string      // protobuf tag style
157	Filename      string      // name of the file in which the extension is defined
158}
159
160func (ed *ExtensionDesc) repeated() bool {
161	t := reflect.TypeOf(ed.ExtensionType)
162	return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
163}
164
// Extension represents an extension in a message.
//
// An Extension holds the extension in decoded form (desc and value), in
// encoded form (enc), or in both.
type Extension struct {
	// When an extension is stored in a message using SetExtension
	// only desc and value are set. When the message is marshaled
	// enc will be set to the encoded form of the extension.
	//
	// When a message is unmarshaled and contains extensions, each
	// extension will have only enc set. When such an extension is
	// accessed using GetExtension (or GetExtensions) desc and value
	// will be set, and enc is cleared.
	desc  *ExtensionDesc // descriptor value was decoded/set with; nil while only enc is known
	value interface{}    // decoded extension value; nil while only enc is known
	enc   []byte         // encoded extension field bytes, including the leading tag varint(s)
}
179
180// SetRawExtension is for testing only.
181func SetRawExtension(base Message, id int32, b []byte) {
182	epb, ok := extendable(base)
183	if !ok {
184		return
185	}
186	extmap := epb.extensionsWrite()
187	extmap[id] = Extension{enc: b}
188}
189
190// isExtensionField returns true iff the given field number is in an extension range.
191func isExtensionField(pb extendableProto, field int32) bool {
192	for _, er := range pb.ExtensionRangeArray() {
193		if er.Start <= field && field <= er.End {
194			return true
195		}
196	}
197	return false
198}
199
200// checkExtensionTypes checks that the given extension is valid for pb.
201func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
202	var pbi interface{} = pb
203	// Check the extended type.
204	if ea, ok := pbi.(extensionAdapter); ok {
205		pbi = ea.extendableProtoV1
206	}
207	if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
208		return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
209	}
210	// Check the range.
211	if !isExtensionField(pb, extension.Field) {
212		return errors.New("proto: bad extension number; not in declared ranges")
213	}
214	return nil
215}
216
217// extPropKey is sufficient to uniquely identify an extension.
218type extPropKey struct {
219	base  reflect.Type
220	field int32
221}
222
223var extProp = struct {
224	sync.RWMutex
225	m map[extPropKey]*Properties
226}{
227	m: make(map[extPropKey]*Properties),
228}
229
230func extensionProperties(ed *ExtensionDesc) *Properties {
231	key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
232
233	extProp.RLock()
234	if prop, ok := extProp.m[key]; ok {
235		extProp.RUnlock()
236		return prop
237	}
238	extProp.RUnlock()
239
240	extProp.Lock()
241	defer extProp.Unlock()
242	// Check again.
243	if prop, ok := extProp.m[key]; ok {
244		return prop
245	}
246
247	prop := new(Properties)
248	prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
249	extProp.m[key] = prop
250	return prop
251}
252
253// encode encodes any unmarshaled (unencoded) extensions in e.
254func encodeExtensions(e *XXX_InternalExtensions) error {
255	m, mu := e.extensionsRead()
256	if m == nil {
257		return nil // fast path
258	}
259	mu.Lock()
260	defer mu.Unlock()
261	return encodeExtensionsMap(m)
262}
263
264// encode encodes any unmarshaled (unencoded) extensions in e.
265func encodeExtensionsMap(m map[int32]Extension) error {
266	for k, e := range m {
267		if e.value == nil || e.desc == nil {
268			// Extension is only in its encoded form.
269			continue
270		}
271
272		// We don't skip extensions that have an encoded form set,
273		// because the extension value may have been mutated after
274		// the last time this function was called.
275
276		et := reflect.TypeOf(e.desc.ExtensionType)
277		props := extensionProperties(e.desc)
278
279		p := NewBuffer(nil)
280		// If e.value has type T, the encoder expects a *struct{ X T }.
281		// Pass a *T with a zero field and hope it all works out.
282		x := reflect.New(et)
283		x.Elem().Set(reflect.ValueOf(e.value))
284		if err := props.enc(p, props, toStructPointer(x)); err != nil {
285			return err
286		}
287		e.enc = p.buf
288		m[k] = e
289	}
290	return nil
291}
292
293func extensionsSize(e *XXX_InternalExtensions) (n int) {
294	m, mu := e.extensionsRead()
295	if m == nil {
296		return 0
297	}
298	mu.Lock()
299	defer mu.Unlock()
300	return extensionsMapSize(m)
301}
302
303func extensionsMapSize(m map[int32]Extension) (n int) {
304	for _, e := range m {
305		if e.value == nil || e.desc == nil {
306			// Extension is only in its encoded form.
307			n += len(e.enc)
308			continue
309		}
310
311		// We don't skip extensions that have an encoded form set,
312		// because the extension value may have been mutated after
313		// the last time this function was called.
314
315		et := reflect.TypeOf(e.desc.ExtensionType)
316		props := extensionProperties(e.desc)
317
318		// If e.value has type T, the encoder expects a *struct{ X T }.
319		// Pass a *T with a zero field and hope it all works out.
320		x := reflect.New(et)
321		x.Elem().Set(reflect.ValueOf(e.value))
322		n += props.size(props, toStructPointer(x))
323	}
324	return
325}
326
327// HasExtension returns whether the given extension is present in pb.
328func HasExtension(pb Message, extension *ExtensionDesc) bool {
329	// TODO: Check types, field numbers, etc.?
330	epb, ok := extendable(pb)
331	if !ok {
332		return false
333	}
334	extmap, mu := epb.extensionsRead()
335	if extmap == nil {
336		return false
337	}
338	mu.Lock()
339	_, ok = extmap[extension.Field]
340	mu.Unlock()
341	return ok
342}
343
344// ClearExtension removes the given extension from pb.
345func ClearExtension(pb Message, extension *ExtensionDesc) {
346	epb, ok := extendable(pb)
347	if !ok {
348		return
349	}
350	// TODO: Check types, field numbers, etc.?
351	extmap := epb.extensionsWrite()
352	delete(extmap, extension.Field)
353}
354
355// GetExtension parses and returns the given extension of pb.
356// If the extension is not present and has no default value it returns ErrMissingExtension.
357func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
358	epb, ok := extendable(pb)
359	if !ok {
360		return nil, errors.New("proto: not an extendable proto")
361	}
362
363	if err := checkExtensionTypes(epb, extension); err != nil {
364		return nil, err
365	}
366
367	emap, mu := epb.extensionsRead()
368	if emap == nil {
369		return defaultExtensionValue(extension)
370	}
371	mu.Lock()
372	defer mu.Unlock()
373	e, ok := emap[extension.Field]
374	if !ok {
375		// defaultExtensionValue returns the default value or
376		// ErrMissingExtension if there is no default.
377		return defaultExtensionValue(extension)
378	}
379
380	if e.value != nil {
381		// Already decoded. Check the descriptor, though.
382		if e.desc != extension {
383			// This shouldn't happen. If it does, it means that
384			// GetExtension was called twice with two different
385			// descriptors with the same field number.
386			return nil, errors.New("proto: descriptor conflict")
387		}
388		return e.value, nil
389	}
390
391	v, err := decodeExtension(e.enc, extension)
392	if err != nil {
393		return nil, err
394	}
395
396	// Remember the decoded version and drop the encoded version.
397	// That way it is safe to mutate what we return.
398	e.value = v
399	e.desc = extension
400	e.enc = nil
401	emap[extension.Field] = e
402	return e.value, nil
403}
404
405// defaultExtensionValue returns the default value for extension.
406// If no default for an extension is defined ErrMissingExtension is returned.
407func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
408	t := reflect.TypeOf(extension.ExtensionType)
409	props := extensionProperties(extension)
410
411	sf, _, err := fieldDefault(t, props)
412	if err != nil {
413		return nil, err
414	}
415
416	if sf == nil || sf.value == nil {
417		// There is no default value.
418		return nil, ErrMissingExtension
419	}
420
421	if t.Kind() != reflect.Ptr {
422		// We do not need to return a Ptr, we can directly return sf.value.
423		return sf.value, nil
424	}
425
426	// We need to return an interface{} that is a pointer to sf.value.
427	value := reflect.New(t).Elem()
428	value.Set(reflect.New(value.Type().Elem()))
429	if sf.kind == reflect.Int32 {
430		// We may have an int32 or an enum, but the underlying data is int32.
431		// Since we can't set an int32 into a non int32 reflect.value directly
432		// set it as a int32.
433		value.Elem().SetInt(int64(sf.value.(int32)))
434	} else {
435		value.Elem().Set(reflect.ValueOf(sf.value))
436	}
437	return value.Interface(), nil
438}
439
440// decodeExtension decodes an extension encoded in b.
441func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
442	o := NewBuffer(b)
443
444	t := reflect.TypeOf(extension.ExtensionType)
445
446	props := extensionProperties(extension)
447
448	// t is a pointer to a struct, pointer to basic type or a slice.
449	// Allocate a "field" to store the pointer/slice itself; the
450	// pointer/slice will be stored here. We pass
451	// the address of this field to props.dec.
452	// This passes a zero field and a *t and lets props.dec
453	// interpret it as a *struct{ x t }.
454	value := reflect.New(t).Elem()
455
456	for {
457		// Discard wire type and field number varint. It isn't needed.
458		if _, err := o.DecodeVarint(); err != nil {
459			return nil, err
460		}
461
462		if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
463			return nil, err
464		}
465
466		if o.index >= len(o.buf) {
467			break
468		}
469	}
470	return value.Interface(), nil
471}
472
473// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
474// The returned slice has the same length as es; missing extensions will appear as nil elements.
475func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
476	epb, ok := extendable(pb)
477	if !ok {
478		return nil, errors.New("proto: not an extendable proto")
479	}
480	extensions = make([]interface{}, len(es))
481	for i, e := range es {
482		extensions[i], err = GetExtension(epb, e)
483		if err == ErrMissingExtension {
484			err = nil
485		}
486		if err != nil {
487			return
488		}
489	}
490	return
491}
492
493// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
494// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
495// just the Field field, which defines the extension's field number.
496func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
497	epb, ok := extendable(pb)
498	if !ok {
499		return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
500	}
501	registeredExtensions := RegisteredExtensions(pb)
502
503	emap, mu := epb.extensionsRead()
504	if emap == nil {
505		return nil, nil
506	}
507	mu.Lock()
508	defer mu.Unlock()
509	extensions := make([]*ExtensionDesc, 0, len(emap))
510	for extid, e := range emap {
511		desc := e.desc
512		if desc == nil {
513			desc = registeredExtensions[extid]
514			if desc == nil {
515				desc = &ExtensionDesc{Field: extid}
516			}
517		}
518
519		extensions = append(extensions, desc)
520	}
521	return extensions, nil
522}
523
524// SetExtension sets the specified extension of pb to the specified value.
525func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
526	epb, ok := extendable(pb)
527	if !ok {
528		return errors.New("proto: not an extendable proto")
529	}
530	if err := checkExtensionTypes(epb, extension); err != nil {
531		return err
532	}
533	typ := reflect.TypeOf(extension.ExtensionType)
534	if typ != reflect.TypeOf(value) {
535		return errors.New("proto: bad extension value type")
536	}
537	// nil extension values need to be caught early, because the
538	// encoder can't distinguish an ErrNil due to a nil extension
539	// from an ErrNil due to a missing field. Extensions are
540	// always optional, so the encoder would just swallow the error
541	// and drop all the extensions from the encoded message.
542	if reflect.ValueOf(value).IsNil() {
543		return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
544	}
545
546	extmap := epb.extensionsWrite()
547	extmap[extension.Field] = Extension{desc: extension, value: value}
548	return nil
549}
550
551// ClearAllExtensions clears all extensions from pb.
552func ClearAllExtensions(pb Message) {
553	epb, ok := extendable(pb)
554	if !ok {
555		return
556	}
557	m := epb.extensionsWrite()
558	for k := range m {
559		delete(m, k)
560	}
561}
562
563// A global registry of extensions.
564// The generated code will register the generated descriptors by calling RegisterExtension.
565
566var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
567
568// RegisterExtension is called from the generated code.
569func RegisterExtension(desc *ExtensionDesc) {
570	st := reflect.TypeOf(desc.ExtendedType).Elem()
571	m := extensionMaps[st]
572	if m == nil {
573		m = make(map[int32]*ExtensionDesc)
574		extensionMaps[st] = m
575	}
576	if _, ok := m[desc.Field]; ok {
577		panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
578	}
579	m[desc.Field] = desc
580}
581
582// RegisteredExtensions returns a map of the registered extensions of a
583// protocol buffer struct, indexed by the extension number.
584// The argument pb should be a nil pointer to the struct type.
585func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
586	return extensionMaps[reflect.TypeOf(pb).Elem()]
587}
588