1// Go support for Protocol Buffers - Google's data interchange format 2// 3// Copyright 2010 The Go Authors. All rights reserved. 4// https://github.com/golang/protobuf 5// 6// Redistribution and use in source and binary forms, with or without 7// modification, are permitted provided that the following conditions are 8// met: 9// 10// * Redistributions of source code must retain the above copyright 11// notice, this list of conditions and the following disclaimer. 12// * Redistributions in binary form must reproduce the above 13// copyright notice, this list of conditions and the following disclaimer 14// in the documentation and/or other materials provided with the 15// distribution. 16// * Neither the name of Google Inc. nor the names of its 17// contributors may be used to endorse or promote products derived from 18// this software without specific prior written permission. 19// 20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 32package proto 33 34/* 35 * Types and routines for supporting protocol buffer extensions. 36 */ 37 38import ( 39 "errors" 40 "fmt" 41 "io" 42 "reflect" 43 "strconv" 44 "sync" 45) 46 47// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. 48var ErrMissingExtension = errors.New("proto: missing extension") 49 50// ExtensionRange represents a range of message extensions for a protocol buffer. 51// Used in code generated by the protocol compiler. 52type ExtensionRange struct { 53 Start, End int32 // both inclusive 54} 55 56// extendableProto is an interface implemented by any protocol buffer generated by the current 57// proto compiler that may be extended. 58type extendableProto interface { 59 Message 60 ExtensionRangeArray() []ExtensionRange 61 extensionsWrite() map[int32]Extension 62 extensionsRead() (map[int32]Extension, sync.Locker) 63} 64 65// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous 66// version of the proto compiler that may be extended. 67type extendableProtoV1 interface { 68 Message 69 ExtensionRangeArray() []ExtensionRange 70 ExtensionMap() map[int32]Extension 71} 72 73// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. 74type extensionAdapter struct { 75 extendableProtoV1 76} 77 78func (e extensionAdapter) extensionsWrite() map[int32]Extension { 79 return e.ExtensionMap() 80} 81 82func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { 83 return e.ExtensionMap(), notLocker{} 84} 85 86// notLocker is a sync.Locker whose Lock and Unlock methods are nops. 87type notLocker struct{} 88 89func (n notLocker) Lock() {} 90func (n notLocker) Unlock() {} 91 92// extendable returns the extendableProto interface for the given generated proto message. 93// If the proto message has the old extension format, it returns a wrapper that implements 94// the extendableProto interface. 95func extendable(p interface{}) (extendableProto, error) { 96 switch p := p.(type) { 97 case extendableProto: 98 if isNilPtr(p) { 99 return nil, fmt.Errorf("proto: nil %T is not extendable", p) 100 } 101 return p, nil 102 case extendableProtoV1: 103 if isNilPtr(p) { 104 return nil, fmt.Errorf("proto: nil %T is not extendable", p) 105 } 106 return extensionAdapter{p}, nil 107 case extensionsBytes: 108 return slowExtensionAdapter{p}, nil 109 } 110 // Don't allocate a specific error containing %T: 111 // this is the hot path for Clone and MarshalText. 112 return nil, errNotExtendable 113} 114 115var errNotExtendable = errors.New("proto: not an extendable proto.Message") 116 117func isNilPtr(x interface{}) bool { 118 v := reflect.ValueOf(x) 119 return v.Kind() == reflect.Ptr && v.IsNil() 120} 121 122// XXX_InternalExtensions is an internal representation of proto extensions. 123// 124// Each generated message struct type embeds an anonymous XXX_InternalExtensions field, 125// thus gaining the unexported 'extensions' method, which can be called only from the proto package. 126// 127// The methods of XXX_InternalExtensions are not concurrency safe in general, 128// but calls to logically read-only methods such as has and get may be executed concurrently. 129type XXX_InternalExtensions struct { 130 // The struct must be indirect so that if a user inadvertently copies a 131 // generated message and its embedded XXX_InternalExtensions, they 132 // avoid the mayhem of a copied mutex. 133 // 134 // The mutex serializes all logically read-only operations to p.extensionMap. 135 // It is up to the client to ensure that write operations to p.extensionMap are 136 // mutually exclusive with other accesses. 137 p *struct { 138 mu sync.Mutex 139 extensionMap map[int32]Extension 140 } 141} 142 143// extensionsWrite returns the extension map, creating it on first use. 144func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { 145 if e.p == nil { 146 e.p = new(struct { 147 mu sync.Mutex 148 extensionMap map[int32]Extension 149 }) 150 e.p.extensionMap = make(map[int32]Extension) 151 } 152 return e.p.extensionMap 153} 154 155// extensionsRead returns the extensions map for read-only use. It may be nil. 156// The caller must hold the returned mutex's lock when accessing Elements within the map. 157func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { 158 if e.p == nil { 159 return nil, nil 160 } 161 return e.p.extensionMap, &e.p.mu 162} 163 164// ExtensionDesc represents an extension specification. 165// Used in generated code from the protocol compiler. 166type ExtensionDesc struct { 167 ExtendedType Message // nil pointer to the type that is being extended 168 ExtensionType interface{} // nil pointer to the extension type 169 Field int32 // field number 170 Name string // fully-qualified name of extension, for text formatting 171 Tag string // protobuf tag style 172 Filename string // name of the file in which the extension is defined 173} 174 175func (ed *ExtensionDesc) repeated() bool { 176 t := reflect.TypeOf(ed.ExtensionType) 177 return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 178} 179 180// Extension represents an extension in a message. 181type Extension struct { 182 // When an extension is stored in a message using SetExtension 183 // only desc and value are set. When the message is marshaled 184 // enc will be set to the encoded form of the message. 185 // 186 // When a message is unmarshaled and contains extensions, each 187 // extension will have only enc set. When such an extension is 188 // accessed using GetExtension (or GetExtensions) desc and value 189 // will be set. 190 desc *ExtensionDesc 191 value interface{} 192 enc []byte 193} 194 195// SetRawExtension is for testing only. 196func SetRawExtension(base Message, id int32, b []byte) { 197 if ebase, ok := base.(extensionsBytes); ok { 198 clearExtension(base, id) 199 ext := ebase.GetExtensions() 200 *ext = append(*ext, b...) 201 return 202 } 203 epb, err := extendable(base) 204 if err != nil { 205 return 206 } 207 extmap := epb.extensionsWrite() 208 extmap[id] = Extension{enc: b} 209} 210 211// isExtensionField returns true iff the given field number is in an extension range. 212func isExtensionField(pb extendableProto, field int32) bool { 213 for _, er := range pb.ExtensionRangeArray() { 214 if er.Start <= field && field <= er.End { 215 return true 216 } 217 } 218 return false 219} 220 221// checkExtensionTypes checks that the given extension is valid for pb. 222func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { 223 var pbi interface{} = pb 224 // Check the extended type. 225 if ea, ok := pbi.(extensionAdapter); ok { 226 pbi = ea.extendableProtoV1 227 } 228 if ea, ok := pbi.(slowExtensionAdapter); ok { 229 pbi = ea.extensionsBytes 230 } 231 if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { 232 return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a) 233 } 234 // Check the range. 235 if !isExtensionField(pb, extension.Field) { 236 return errors.New("proto: bad extension number; not in declared ranges") 237 } 238 return nil 239} 240 241// extPropKey is sufficient to uniquely identify an extension. 242type extPropKey struct { 243 base reflect.Type 244 field int32 245} 246 247var extProp = struct { 248 sync.RWMutex 249 m map[extPropKey]*Properties 250}{ 251 m: make(map[extPropKey]*Properties), 252} 253 254func extensionProperties(ed *ExtensionDesc) *Properties { 255 key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} 256 257 extProp.RLock() 258 if prop, ok := extProp.m[key]; ok { 259 extProp.RUnlock() 260 return prop 261 } 262 extProp.RUnlock() 263 264 extProp.Lock() 265 defer extProp.Unlock() 266 // Check again. 267 if prop, ok := extProp.m[key]; ok { 268 return prop 269 } 270 271 prop := new(Properties) 272 prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) 273 extProp.m[key] = prop 274 return prop 275} 276 277// HasExtension returns whether the given extension is present in pb. 278func HasExtension(pb Message, extension *ExtensionDesc) bool { 279 if epb, doki := pb.(extensionsBytes); doki { 280 ext := epb.GetExtensions() 281 buf := *ext 282 o := 0 283 for o < len(buf) { 284 tag, n := DecodeVarint(buf[o:]) 285 fieldNum := int32(tag >> 3) 286 if int32(fieldNum) == extension.Field { 287 return true 288 } 289 wireType := int(tag & 0x7) 290 o += n 291 l, err := size(buf[o:], wireType) 292 if err != nil { 293 return false 294 } 295 o += l 296 } 297 return false 298 } 299 // TODO: Check types, field numbers, etc.? 300 epb, err := extendable(pb) 301 if err != nil { 302 return false 303 } 304 extmap, mu := epb.extensionsRead() 305 if extmap == nil { 306 return false 307 } 308 mu.Lock() 309 _, ok := extmap[extension.Field] 310 mu.Unlock() 311 return ok 312} 313 314// ClearExtension removes the given extension from pb. 315func ClearExtension(pb Message, extension *ExtensionDesc) { 316 clearExtension(pb, extension.Field) 317} 318 319func clearExtension(pb Message, fieldNum int32) { 320 if epb, ok := pb.(extensionsBytes); ok { 321 offset := 0 322 for offset != -1 { 323 offset = deleteExtension(epb, fieldNum, offset) 324 } 325 return 326 } 327 epb, err := extendable(pb) 328 if err != nil { 329 return 330 } 331 // TODO: Check types, field numbers, etc.? 332 extmap := epb.extensionsWrite() 333 delete(extmap, fieldNum) 334} 335 336// GetExtension retrieves a proto2 extended field from pb. 337// 338// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil), 339// then GetExtension parses the encoded field and returns a Go value of the specified type. 340// If the field is not present, then the default value is returned (if one is specified), 341// otherwise ErrMissingExtension is reported. 342// 343// If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil), 344// then GetExtension returns the raw encoded bytes of the field extension. 345func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { 346 if epb, doki := pb.(extensionsBytes); doki { 347 ext := epb.GetExtensions() 348 return decodeExtensionFromBytes(extension, *ext) 349 } 350 351 epb, err := extendable(pb) 352 if err != nil { 353 return nil, err 354 } 355 356 if extension.ExtendedType != nil { 357 // can only check type if this is a complete descriptor 358 if cerr := checkExtensionTypes(epb, extension); cerr != nil { 359 return nil, cerr 360 } 361 } 362 363 emap, mu := epb.extensionsRead() 364 if emap == nil { 365 return defaultExtensionValue(extension) 366 } 367 mu.Lock() 368 defer mu.Unlock() 369 e, ok := emap[extension.Field] 370 if !ok { 371 // defaultExtensionValue returns the default value or 372 // ErrMissingExtension if there is no default. 373 return defaultExtensionValue(extension) 374 } 375 376 if e.value != nil { 377 // Already decoded. Check the descriptor, though. 378 if e.desc != extension { 379 // This shouldn't happen. If it does, it means that 380 // GetExtension was called twice with two different 381 // descriptors with the same field number. 382 return nil, errors.New("proto: descriptor conflict") 383 } 384 return e.value, nil 385 } 386 387 if extension.ExtensionType == nil { 388 // incomplete descriptor 389 return e.enc, nil 390 } 391 392 v, err := decodeExtension(e.enc, extension) 393 if err != nil { 394 return nil, err 395 } 396 397 // Remember the decoded version and drop the encoded version. 398 // That way it is safe to mutate what we return. 399 e.value = v 400 e.desc = extension 401 e.enc = nil 402 emap[extension.Field] = e 403 return e.value, nil 404} 405 406// defaultExtensionValue returns the default value for extension. 407// If no default for an extension is defined ErrMissingExtension is returned. 408func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) { 409 if extension.ExtensionType == nil { 410 // incomplete descriptor, so no default 411 return nil, ErrMissingExtension 412 } 413 414 t := reflect.TypeOf(extension.ExtensionType) 415 props := extensionProperties(extension) 416 417 sf, _, err := fieldDefault(t, props) 418 if err != nil { 419 return nil, err 420 } 421 422 if sf == nil || sf.value == nil { 423 // There is no default value. 424 return nil, ErrMissingExtension 425 } 426 427 if t.Kind() != reflect.Ptr { 428 // We do not need to return a Ptr, we can directly return sf.value. 429 return sf.value, nil 430 } 431 432 // We need to return an interface{} that is a pointer to sf.value. 433 value := reflect.New(t).Elem() 434 value.Set(reflect.New(value.Type().Elem())) 435 if sf.kind == reflect.Int32 { 436 // We may have an int32 or an enum, but the underlying data is int32. 437 // Since we can't set an int32 into a non int32 reflect.value directly 438 // set it as a int32. 439 value.Elem().SetInt(int64(sf.value.(int32))) 440 } else { 441 value.Elem().Set(reflect.ValueOf(sf.value)) 442 } 443 return value.Interface(), nil 444} 445 446// decodeExtension decodes an extension encoded in b. 447func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { 448 t := reflect.TypeOf(extension.ExtensionType) 449 unmarshal := typeUnmarshaler(t, extension.Tag) 450 451 // t is a pointer to a struct, pointer to basic type or a slice. 452 // Allocate space to store the pointer/slice. 453 value := reflect.New(t).Elem() 454 455 var err error 456 for { 457 x, n := decodeVarint(b) 458 if n == 0 { 459 return nil, io.ErrUnexpectedEOF 460 } 461 b = b[n:] 462 wire := int(x) & 7 463 464 b, err = unmarshal(b, valToPointer(value.Addr()), wire) 465 if err != nil { 466 return nil, err 467 } 468 469 if len(b) == 0 { 470 break 471 } 472 } 473 return value.Interface(), nil 474} 475 476// GetExtensions returns a slice of the extensions present in pb that are also listed in es. 477// The returned slice has the same length as es; missing extensions will appear as nil elements. 478func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { 479 epb, err := extendable(pb) 480 if err != nil { 481 return nil, err 482 } 483 extensions = make([]interface{}, len(es)) 484 for i, e := range es { 485 extensions[i], err = GetExtension(epb, e) 486 if err == ErrMissingExtension { 487 err = nil 488 } 489 if err != nil { 490 return 491 } 492 } 493 return 494} 495 496// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order. 497// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing 498// just the Field field, which defines the extension's field number. 499func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { 500 epb, err := extendable(pb) 501 if err != nil { 502 return nil, err 503 } 504 registeredExtensions := RegisteredExtensions(pb) 505 506 emap, mu := epb.extensionsRead() 507 if emap == nil { 508 return nil, nil 509 } 510 mu.Lock() 511 defer mu.Unlock() 512 extensions := make([]*ExtensionDesc, 0, len(emap)) 513 for extid, e := range emap { 514 desc := e.desc 515 if desc == nil { 516 desc = registeredExtensions[extid] 517 if desc == nil { 518 desc = &ExtensionDesc{Field: extid} 519 } 520 } 521 522 extensions = append(extensions, desc) 523 } 524 return extensions, nil 525} 526 527// SetExtension sets the specified extension of pb to the specified value. 528func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { 529 if epb, ok := pb.(extensionsBytes); ok { 530 newb, err := encodeExtension(extension, value) 531 if err != nil { 532 return err 533 } 534 bb := epb.GetExtensions() 535 *bb = append(*bb, newb...) 536 return nil 537 } 538 epb, err := extendable(pb) 539 if err != nil { 540 return err 541 } 542 if err := checkExtensionTypes(epb, extension); err != nil { 543 return err 544 } 545 typ := reflect.TypeOf(extension.ExtensionType) 546 if typ != reflect.TypeOf(value) { 547 return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", value, extension.ExtensionType) 548 } 549 // nil extension values need to be caught early, because the 550 // encoder can't distinguish an ErrNil due to a nil extension 551 // from an ErrNil due to a missing field. Extensions are 552 // always optional, so the encoder would just swallow the error 553 // and drop all the extensions from the encoded message. 554 if reflect.ValueOf(value).IsNil() { 555 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) 556 } 557 558 extmap := epb.extensionsWrite() 559 extmap[extension.Field] = Extension{desc: extension, value: value} 560 return nil 561} 562 563// ClearAllExtensions clears all extensions from pb. 564func ClearAllExtensions(pb Message) { 565 if epb, doki := pb.(extensionsBytes); doki { 566 ext := epb.GetExtensions() 567 *ext = []byte{} 568 return 569 } 570 epb, err := extendable(pb) 571 if err != nil { 572 return 573 } 574 m := epb.extensionsWrite() 575 for k := range m { 576 delete(m, k) 577 } 578} 579 580// A global registry of extensions. 581// The generated code will register the generated descriptors by calling RegisterExtension. 582 583var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) 584 585// RegisterExtension is called from the generated code. 586func RegisterExtension(desc *ExtensionDesc) { 587 st := reflect.TypeOf(desc.ExtendedType).Elem() 588 m := extensionMaps[st] 589 if m == nil { 590 m = make(map[int32]*ExtensionDesc) 591 extensionMaps[st] = m 592 } 593 if _, ok := m[desc.Field]; ok { 594 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) 595 } 596 m[desc.Field] = desc 597} 598 599// RegisteredExtensions returns a map of the registered extensions of a 600// protocol buffer struct, indexed by the extension number. 601// The argument pb should be a nil pointer to the struct type. 602func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { 603 return extensionMaps[reflect.TypeOf(pb).Elem()] 604} 605