1// Go support for Protocol Buffers - Google's data interchange format 2// 3// Copyright 2010 The Go Authors. All rights reserved. 4// https://github.com/golang/protobuf 5// 6// Redistribution and use in source and binary forms, with or without 7// modification, are permitted provided that the following conditions are 8// met: 9// 10// * Redistributions of source code must retain the above copyright 11// notice, this list of conditions and the following disclaimer. 12// * Redistributions in binary form must reproduce the above 13// copyright notice, this list of conditions and the following disclaimer 14// in the documentation and/or other materials provided with the 15// distribution. 16// * Neither the name of Google Inc. nor the names of its 17// contributors may be used to endorse or promote products derived from 18// this software without specific prior written permission. 19// 20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 32package proto 33 34/* 35 * Types and routines for supporting protocol buffer extensions. 36 */ 37 38import ( 39 "errors" 40 "fmt" 41 "reflect" 42 "strconv" 43 "sync" 44) 45 46// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. 47var ErrMissingExtension = errors.New("proto: missing extension") 48 49// ExtensionRange represents a range of message extensions for a protocol buffer. 50// Used in code generated by the protocol compiler. 51type ExtensionRange struct { 52 Start, End int32 // both inclusive 53} 54 55// extendableProto is an interface implemented by any protocol buffer that may be extended. 56type extendableProto interface { 57 Message 58 ExtensionRangeArray() []ExtensionRange 59 ExtensionMap() map[int32]Extension 60} 61 62var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() 63 64// ExtensionDesc represents an extension specification. 65// Used in generated code from the protocol compiler. 66type ExtensionDesc struct { 67 ExtendedType Message // nil pointer to the type that is being extended 68 ExtensionType interface{} // nil pointer to the extension type 69 Field int32 // field number 70 Name string // fully-qualified name of extension, for text formatting 71 Tag string // protobuf tag style 72} 73 74func (ed *ExtensionDesc) repeated() bool { 75 t := reflect.TypeOf(ed.ExtensionType) 76 return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 77} 78 79// Extension represents an extension in a message. 80type Extension struct { 81 // When an extension is stored in a message using SetExtension 82 // only desc and value are set. When the message is marshaled 83 // enc will be set to the encoded form of the message. 84 // 85 // When a message is unmarshaled and contains extensions, each 86 // extension will have only enc set. When such an extension is 87 // accessed using GetExtension (or GetExtensions) desc and value 88 // will be set. 89 desc *ExtensionDesc 90 value interface{} 91 enc []byte 92} 93 94// SetRawExtension is for testing only. 95func SetRawExtension(base extendableProto, id int32, b []byte) { 96 base.ExtensionMap()[id] = Extension{enc: b} 97} 98 99// isExtensionField returns true iff the given field number is in an extension range. 100func isExtensionField(pb extendableProto, field int32) bool { 101 for _, er := range pb.ExtensionRangeArray() { 102 if er.Start <= field && field <= er.End { 103 return true 104 } 105 } 106 return false 107} 108 109// checkExtensionTypes checks that the given extension is valid for pb. 110func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { 111 // Check the extended type. 112 if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b { 113 return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) 114 } 115 // Check the range. 116 if !isExtensionField(pb, extension.Field) { 117 return errors.New("proto: bad extension number; not in declared ranges") 118 } 119 return nil 120} 121 122// extPropKey is sufficient to uniquely identify an extension. 123type extPropKey struct { 124 base reflect.Type 125 field int32 126} 127 128var extProp = struct { 129 sync.RWMutex 130 m map[extPropKey]*Properties 131}{ 132 m: make(map[extPropKey]*Properties), 133} 134 135func extensionProperties(ed *ExtensionDesc) *Properties { 136 key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} 137 138 extProp.RLock() 139 if prop, ok := extProp.m[key]; ok { 140 extProp.RUnlock() 141 return prop 142 } 143 extProp.RUnlock() 144 145 extProp.Lock() 146 defer extProp.Unlock() 147 // Check again. 148 if prop, ok := extProp.m[key]; ok { 149 return prop 150 } 151 152 prop := new(Properties) 153 prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) 154 extProp.m[key] = prop 155 return prop 156} 157 158// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m. 159func encodeExtensionMap(m map[int32]Extension) error { 160 for k, e := range m { 161 if e.value == nil || e.desc == nil { 162 // Extension is only in its encoded form. 163 continue 164 } 165 166 // We don't skip extensions that have an encoded form set, 167 // because the extension value may have been mutated after 168 // the last time this function was called. 169 170 et := reflect.TypeOf(e.desc.ExtensionType) 171 props := extensionProperties(e.desc) 172 173 p := NewBuffer(nil) 174 // If e.value has type T, the encoder expects a *struct{ X T }. 175 // Pass a *T with a zero field and hope it all works out. 176 x := reflect.New(et) 177 x.Elem().Set(reflect.ValueOf(e.value)) 178 if err := props.enc(p, props, toStructPointer(x)); err != nil { 179 return err 180 } 181 e.enc = p.buf 182 m[k] = e 183 } 184 return nil 185} 186 187func sizeExtensionMap(m map[int32]Extension) (n int) { 188 for _, e := range m { 189 if e.value == nil || e.desc == nil { 190 // Extension is only in its encoded form. 191 n += len(e.enc) 192 continue 193 } 194 195 // We don't skip extensions that have an encoded form set, 196 // because the extension value may have been mutated after 197 // the last time this function was called. 198 199 et := reflect.TypeOf(e.desc.ExtensionType) 200 props := extensionProperties(e.desc) 201 202 // If e.value has type T, the encoder expects a *struct{ X T }. 203 // Pass a *T with a zero field and hope it all works out. 204 x := reflect.New(et) 205 x.Elem().Set(reflect.ValueOf(e.value)) 206 n += props.size(props, toStructPointer(x)) 207 } 208 return 209} 210 211// HasExtension returns whether the given extension is present in pb. 212func HasExtension(pb extendableProto, extension *ExtensionDesc) bool { 213 // TODO: Check types, field numbers, etc.? 214 _, ok := pb.ExtensionMap()[extension.Field] 215 return ok 216} 217 218// ClearExtension removes the given extension from pb. 219func ClearExtension(pb extendableProto, extension *ExtensionDesc) { 220 // TODO: Check types, field numbers, etc.? 221 delete(pb.ExtensionMap(), extension.Field) 222} 223 224// GetExtension parses and returns the given extension of pb. 225// If the extension is not present it returns ErrMissingExtension. 226func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) { 227 if err := checkExtensionTypes(pb, extension); err != nil { 228 return nil, err 229 } 230 231 emap := pb.ExtensionMap() 232 e, ok := emap[extension.Field] 233 if !ok { 234 return nil, ErrMissingExtension 235 } 236 if e.value != nil { 237 // Already decoded. Check the descriptor, though. 238 if e.desc != extension { 239 // This shouldn't happen. If it does, it means that 240 // GetExtension was called twice with two different 241 // descriptors with the same field number. 242 return nil, errors.New("proto: descriptor conflict") 243 } 244 return e.value, nil 245 } 246 247 v, err := decodeExtension(e.enc, extension) 248 if err != nil { 249 return nil, err 250 } 251 252 // Remember the decoded version and drop the encoded version. 253 // That way it is safe to mutate what we return. 254 e.value = v 255 e.desc = extension 256 e.enc = nil 257 emap[extension.Field] = e 258 return e.value, nil 259} 260 261// decodeExtension decodes an extension encoded in b. 262func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { 263 o := NewBuffer(b) 264 265 t := reflect.TypeOf(extension.ExtensionType) 266 rep := extension.repeated() 267 268 props := extensionProperties(extension) 269 270 // t is a pointer to a struct, pointer to basic type or a slice. 271 // Allocate a "field" to store the pointer/slice itself; the 272 // pointer/slice will be stored here. We pass 273 // the address of this field to props.dec. 274 // This passes a zero field and a *t and lets props.dec 275 // interpret it as a *struct{ x t }. 276 value := reflect.New(t).Elem() 277 278 for { 279 // Discard wire type and field number varint. It isn't needed. 280 if _, err := o.DecodeVarint(); err != nil { 281 return nil, err 282 } 283 284 if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil { 285 return nil, err 286 } 287 288 if !rep || o.index >= len(o.buf) { 289 break 290 } 291 } 292 return value.Interface(), nil 293} 294 295// GetExtensions returns a slice of the extensions present in pb that are also listed in es. 296// The returned slice has the same length as es; missing extensions will appear as nil elements. 297func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { 298 epb, ok := pb.(extendableProto) 299 if !ok { 300 err = errors.New("proto: not an extendable proto") 301 return 302 } 303 extensions = make([]interface{}, len(es)) 304 for i, e := range es { 305 extensions[i], err = GetExtension(epb, e) 306 if err == ErrMissingExtension { 307 err = nil 308 } 309 if err != nil { 310 return 311 } 312 } 313 return 314} 315 316// SetExtension sets the specified extension of pb to the specified value. 317func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error { 318 if err := checkExtensionTypes(pb, extension); err != nil { 319 return err 320 } 321 typ := reflect.TypeOf(extension.ExtensionType) 322 if typ != reflect.TypeOf(value) { 323 return errors.New("proto: bad extension value type") 324 } 325 // nil extension values need to be caught early, because the 326 // encoder can't distinguish an ErrNil due to a nil extension 327 // from an ErrNil due to a missing field. Extensions are 328 // always optional, so the encoder would just swallow the error 329 // and drop all the extensions from the encoded message. 330 if reflect.ValueOf(value).IsNil() { 331 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) 332 } 333 334 pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} 335 return nil 336} 337 338// A global registry of extensions. 339// The generated code will register the generated descriptors by calling RegisterExtension. 340 341var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) 342 343// RegisterExtension is called from the generated code. 344func RegisterExtension(desc *ExtensionDesc) { 345 st := reflect.TypeOf(desc.ExtendedType).Elem() 346 m := extensionMaps[st] 347 if m == nil { 348 m = make(map[int32]*ExtensionDesc) 349 extensionMaps[st] = m 350 } 351 if _, ok := m[desc.Field]; ok { 352 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) 353 } 354 m[desc.Field] = desc 355} 356 357// RegisteredExtensions returns a map of the registered extensions of a 358// protocol buffer struct, indexed by the extension number. 359// The argument pb should be a nil pointer to the struct type. 360func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { 361 return extensionMaps[reflect.TypeOf(pb).Elem()] 362} 363