1// Go support for Protocol Buffers - Google's data interchange format 2// 3// Copyright 2010 The Go Authors. All rights reserved. 4// https://github.com/golang/protobuf 5// 6// Redistribution and use in source and binary forms, with or without 7// modification, are permitted provided that the following conditions are 8// met: 9// 10// * Redistributions of source code must retain the above copyright 11// notice, this list of conditions and the following disclaimer. 12// * Redistributions in binary form must reproduce the above 13// copyright notice, this list of conditions and the following disclaimer 14// in the documentation and/or other materials provided with the 15// distribution. 16// * Neither the name of Google Inc. nor the names of its 17// contributors may be used to endorse or promote products derived from 18// this software without specific prior written permission. 19// 20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 32package proto 33 34/* 35 * Types and routines for supporting protocol buffer extensions. 36 */ 37 38import ( 39 "errors" 40 "fmt" 41 "reflect" 42 "strconv" 43 "sync" 44) 45 46// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. 47var ErrMissingExtension = errors.New("proto: missing extension") 48 49// ExtensionRange represents a range of message extensions for a protocol buffer. 50// Used in code generated by the protocol compiler. 51type ExtensionRange struct { 52 Start, End int32 // both inclusive 53} 54 55// extendableProto is an interface implemented by any protocol buffer that may be extended. 56type extendableProto interface { 57 Message 58 ExtensionRangeArray() []ExtensionRange 59 ExtensionMap() map[int32]Extension 60} 61 62var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() 63 64// ExtensionDesc represents an extension specification. 65// Used in generated code from the protocol compiler. 66type ExtensionDesc struct { 67 ExtendedType Message // nil pointer to the type that is being extended 68 ExtensionType interface{} // nil pointer to the extension type 69 Field int32 // field number 70 Name string // fully-qualified name of extension, for text formatting 71 Tag string // protobuf tag style 72} 73 74func (ed *ExtensionDesc) repeated() bool { 75 t := reflect.TypeOf(ed.ExtensionType) 76 return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 77} 78 79// Extension represents an extension in a message. 80type Extension struct { 81 // When an extension is stored in a message using SetExtension 82 // only desc and value are set. When the message is marshaled 83 // enc will be set to the encoded form of the message. 84 // 85 // When a message is unmarshaled and contains extensions, each 86 // extension will have only enc set. When such an extension is 87 // accessed using GetExtension (or GetExtensions) desc and value 88 // will be set. 89 desc *ExtensionDesc 90 value interface{} 91 enc []byte 92} 93 94// SetRawExtension is for testing only. 95func SetRawExtension(base extendableProto, id int32, b []byte) { 96 base.ExtensionMap()[id] = Extension{enc: b} 97} 98 99// isExtensionField returns true iff the given field number is in an extension range. 100func isExtensionField(pb extendableProto, field int32) bool { 101 for _, er := range pb.ExtensionRangeArray() { 102 if er.Start <= field && field <= er.End { 103 return true 104 } 105 } 106 return false 107} 108 109// checkExtensionTypes checks that the given extension is valid for pb. 110func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { 111 // Check the extended type. 112 if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b { 113 return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) 114 } 115 // Check the range. 116 if !isExtensionField(pb, extension.Field) { 117 return errors.New("proto: bad extension number; not in declared ranges") 118 } 119 return nil 120} 121 122// extPropKey is sufficient to uniquely identify an extension. 123type extPropKey struct { 124 base reflect.Type 125 field int32 126} 127 128var extProp = struct { 129 sync.RWMutex 130 m map[extPropKey]*Properties 131}{ 132 m: make(map[extPropKey]*Properties), 133} 134 135func extensionProperties(ed *ExtensionDesc) *Properties { 136 key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} 137 138 extProp.RLock() 139 if prop, ok := extProp.m[key]; ok { 140 extProp.RUnlock() 141 return prop 142 } 143 extProp.RUnlock() 144 145 extProp.Lock() 146 defer extProp.Unlock() 147 // Check again. 148 if prop, ok := extProp.m[key]; ok { 149 return prop 150 } 151 152 prop := new(Properties) 153 prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) 154 extProp.m[key] = prop 155 return prop 156} 157 158// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m. 159func encodeExtensionMap(m map[int32]Extension) error { 160 for k, e := range m { 161 if e.value == nil || e.desc == nil { 162 // Extension is only in its encoded form. 163 continue 164 } 165 166 // We don't skip extensions that have an encoded form set, 167 // because the extension value may have been mutated after 168 // the last time this function was called. 169 170 et := reflect.TypeOf(e.desc.ExtensionType) 171 props := extensionProperties(e.desc) 172 173 p := NewBuffer(nil) 174 // If e.value has type T, the encoder expects a *struct{ X T }. 175 // Pass a *T with a zero field and hope it all works out. 176 x := reflect.New(et) 177 x.Elem().Set(reflect.ValueOf(e.value)) 178 if err := props.enc(p, props, toStructPointer(x)); err != nil { 179 return err 180 } 181 e.enc = p.buf 182 m[k] = e 183 } 184 return nil 185} 186 187func sizeExtensionMap(m map[int32]Extension) (n int) { 188 for _, e := range m { 189 if e.value == nil || e.desc == nil { 190 // Extension is only in its encoded form. 191 n += len(e.enc) 192 continue 193 } 194 195 // We don't skip extensions that have an encoded form set, 196 // because the extension value may have been mutated after 197 // the last time this function was called. 198 199 et := reflect.TypeOf(e.desc.ExtensionType) 200 props := extensionProperties(e.desc) 201 202 // If e.value has type T, the encoder expects a *struct{ X T }. 203 // Pass a *T with a zero field and hope it all works out. 204 x := reflect.New(et) 205 x.Elem().Set(reflect.ValueOf(e.value)) 206 n += props.size(props, toStructPointer(x)) 207 } 208 return 209} 210 211// HasExtension returns whether the given extension is present in pb. 212func HasExtension(pb extendableProto, extension *ExtensionDesc) bool { 213 // TODO: Check types, field numbers, etc.? 214 _, ok := pb.ExtensionMap()[extension.Field] 215 return ok 216} 217 218// ClearExtension removes the given extension from pb. 219func ClearExtension(pb extendableProto, extension *ExtensionDesc) { 220 // TODO: Check types, field numbers, etc.? 221 delete(pb.ExtensionMap(), extension.Field) 222} 223 224// GetExtension parses and returns the given extension of pb. 225// If the extension is not present and has no default value it returns ErrMissingExtension. 226func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) { 227 if err := checkExtensionTypes(pb, extension); err != nil { 228 return nil, err 229 } 230 231 emap := pb.ExtensionMap() 232 e, ok := emap[extension.Field] 233 if !ok { 234 // defaultExtensionValue returns the default value or 235 // ErrMissingExtension if there is no default. 236 return defaultExtensionValue(extension) 237 } 238 239 if e.value != nil { 240 // Already decoded. Check the descriptor, though. 241 if e.desc != extension { 242 // This shouldn't happen. If it does, it means that 243 // GetExtension was called twice with two different 244 // descriptors with the same field number. 245 return nil, errors.New("proto: descriptor conflict") 246 } 247 return e.value, nil 248 } 249 250 v, err := decodeExtension(e.enc, extension) 251 if err != nil { 252 return nil, err 253 } 254 255 // Remember the decoded version and drop the encoded version. 256 // That way it is safe to mutate what we return. 257 e.value = v 258 e.desc = extension 259 e.enc = nil 260 emap[extension.Field] = e 261 return e.value, nil 262} 263 264// defaultExtensionValue returns the default value for extension. 265// If no default for an extension is defined ErrMissingExtension is returned. 266func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) { 267 t := reflect.TypeOf(extension.ExtensionType) 268 props := extensionProperties(extension) 269 270 sf, _, err := fieldDefault(t, props) 271 if err != nil { 272 return nil, err 273 } 274 275 if sf == nil || sf.value == nil { 276 // There is no default value. 277 return nil, ErrMissingExtension 278 } 279 280 if t.Kind() != reflect.Ptr { 281 // We do not need to return a Ptr, we can directly return sf.value. 282 return sf.value, nil 283 } 284 285 // We need to return an interface{} that is a pointer to sf.value. 286 value := reflect.New(t).Elem() 287 value.Set(reflect.New(value.Type().Elem())) 288 if sf.kind == reflect.Int32 { 289 // We may have an int32 or an enum, but the underlying data is int32. 290 // Since we can't set an int32 into a non int32 reflect.value directly 291 // set it as a int32. 292 value.Elem().SetInt(int64(sf.value.(int32))) 293 } else { 294 value.Elem().Set(reflect.ValueOf(sf.value)) 295 } 296 return value.Interface(), nil 297} 298 299// decodeExtension decodes an extension encoded in b. 300func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { 301 o := NewBuffer(b) 302 303 t := reflect.TypeOf(extension.ExtensionType) 304 305 props := extensionProperties(extension) 306 307 // t is a pointer to a struct, pointer to basic type or a slice. 308 // Allocate a "field" to store the pointer/slice itself; the 309 // pointer/slice will be stored here. We pass 310 // the address of this field to props.dec. 311 // This passes a zero field and a *t and lets props.dec 312 // interpret it as a *struct{ x t }. 313 value := reflect.New(t).Elem() 314 315 for { 316 // Discard wire type and field number varint. It isn't needed. 317 if _, err := o.DecodeVarint(); err != nil { 318 return nil, err 319 } 320 321 if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil { 322 return nil, err 323 } 324 325 if o.index >= len(o.buf) { 326 break 327 } 328 } 329 return value.Interface(), nil 330} 331 332// GetExtensions returns a slice of the extensions present in pb that are also listed in es. 333// The returned slice has the same length as es; missing extensions will appear as nil elements. 334func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { 335 epb, ok := pb.(extendableProto) 336 if !ok { 337 err = errors.New("proto: not an extendable proto") 338 return 339 } 340 extensions = make([]interface{}, len(es)) 341 for i, e := range es { 342 extensions[i], err = GetExtension(epb, e) 343 if err == ErrMissingExtension { 344 err = nil 345 } 346 if err != nil { 347 return 348 } 349 } 350 return 351} 352 353// SetExtension sets the specified extension of pb to the specified value. 354func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error { 355 if err := checkExtensionTypes(pb, extension); err != nil { 356 return err 357 } 358 typ := reflect.TypeOf(extension.ExtensionType) 359 if typ != reflect.TypeOf(value) { 360 return errors.New("proto: bad extension value type") 361 } 362 // nil extension values need to be caught early, because the 363 // encoder can't distinguish an ErrNil due to a nil extension 364 // from an ErrNil due to a missing field. Extensions are 365 // always optional, so the encoder would just swallow the error 366 // and drop all the extensions from the encoded message. 367 if reflect.ValueOf(value).IsNil() { 368 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) 369 } 370 371 pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} 372 return nil 373} 374 375// A global registry of extensions. 376// The generated code will register the generated descriptors by calling RegisterExtension. 377 378var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) 379 380// RegisterExtension is called from the generated code. 381func RegisterExtension(desc *ExtensionDesc) { 382 st := reflect.TypeOf(desc.ExtendedType).Elem() 383 m := extensionMaps[st] 384 if m == nil { 385 m = make(map[int32]*ExtensionDesc) 386 extensionMaps[st] = m 387 } 388 if _, ok := m[desc.Field]; ok { 389 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) 390 } 391 m[desc.Field] = desc 392} 393 394// RegisteredExtensions returns a map of the registered extensions of a 395// protocol buffer struct, indexed by the extension number. 396// The argument pb should be a nil pointer to the struct type. 397func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { 398 return extensionMaps[reflect.TypeOf(pb).Elem()] 399} 400