1// Protocol Buffers for Go with Gadgets 2// 3// Copyright (c) 2013, The GoGo Authors. All rights reserved. 4// http://github.com/gogo/protobuf 5// 6// Redistribution and use in source and binary forms, with or without 7// modification, are permitted provided that the following conditions are 8// met: 9// 10// * Redistributions of source code must retain the above copyright 11// notice, this list of conditions and the following disclaimer. 12// * Redistributions in binary form must reproduce the above 13// copyright notice, this list of conditions and the following disclaimer 14// in the documentation and/or other materials provided with the 15// distribution. 16// 17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 29package proto 30 31import ( 32 "bytes" 33 "errors" 34 "fmt" 35 "io" 36 "reflect" 37 "sort" 38 "strings" 39 "sync" 40) 41 42type extensionsBytes interface { 43 Message 44 ExtensionRangeArray() []ExtensionRange 45 GetExtensions() *[]byte 46} 47 48type slowExtensionAdapter struct { 49 extensionsBytes 50} 51 52func (s slowExtensionAdapter) extensionsWrite() map[int32]Extension { 53 panic("Please report a bug to github.com/gogo/protobuf if you see this message: Writing extensions is not supported for extensions stored in a byte slice field.") 54} 55 56func (s slowExtensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { 57 b := s.GetExtensions() 58 m, err := BytesToExtensionsMap(*b) 59 if err != nil { 60 panic(err) 61 } 62 return m, notLocker{} 63} 64 65func GetBoolExtension(pb Message, extension *ExtensionDesc, ifnotset bool) bool { 66 if reflect.ValueOf(pb).IsNil() { 67 return ifnotset 68 } 69 value, err := GetExtension(pb, extension) 70 if err != nil { 71 return ifnotset 72 } 73 if value == nil { 74 return ifnotset 75 } 76 if value.(*bool) == nil { 77 return ifnotset 78 } 79 return *(value.(*bool)) 80} 81 82func (this *Extension) Equal(that *Extension) bool { 83 if err := this.Encode(); err != nil { 84 return false 85 } 86 if err := that.Encode(); err != nil { 87 return false 88 } 89 return bytes.Equal(this.enc, that.enc) 90} 91 92func (this *Extension) Compare(that *Extension) int { 93 if err := this.Encode(); err != nil { 94 return 1 95 } 96 if err := that.Encode(); err != nil { 97 return -1 98 } 99 return bytes.Compare(this.enc, that.enc) 100} 101 102func SizeOfInternalExtension(m extendableProto) (n int) { 103 info := getMarshalInfo(reflect.TypeOf(m)) 104 return info.sizeV1Extensions(m.extensionsWrite()) 105} 106 107type sortableMapElem struct { 108 field int32 109 ext Extension 110} 111 112func newSortableExtensionsFromMap(m map[int32]Extension) sortableExtensions { 113 s := make(sortableExtensions, 0, len(m)) 114 for k, v := range m { 115 s = append(s, &sortableMapElem{field: k, ext: v}) 116 } 117 return s 118} 119 120type sortableExtensions []*sortableMapElem 121 122func (this sortableExtensions) Len() int { return len(this) } 123 124func (this sortableExtensions) Swap(i, j int) { this[i], this[j] = this[j], this[i] } 125 126func (this sortableExtensions) Less(i, j int) bool { return this[i].field < this[j].field } 127 128func (this sortableExtensions) String() string { 129 sort.Sort(this) 130 ss := make([]string, len(this)) 131 for i := range this { 132 ss[i] = fmt.Sprintf("%d: %v", this[i].field, this[i].ext) 133 } 134 return "map[" + strings.Join(ss, ",") + "]" 135} 136 137func StringFromInternalExtension(m extendableProto) string { 138 return StringFromExtensionsMap(m.extensionsWrite()) 139} 140 141func StringFromExtensionsMap(m map[int32]Extension) string { 142 return newSortableExtensionsFromMap(m).String() 143} 144 145func StringFromExtensionsBytes(ext []byte) string { 146 m, err := BytesToExtensionsMap(ext) 147 if err != nil { 148 panic(err) 149 } 150 return StringFromExtensionsMap(m) 151} 152 153func EncodeInternalExtension(m extendableProto, data []byte) (n int, err error) { 154 return EncodeExtensionMap(m.extensionsWrite(), data) 155} 156 157func EncodeExtensionMap(m map[int32]Extension, data []byte) (n int, err error) { 158 o := 0 159 for _, e := range m { 160 if err := e.Encode(); err != nil { 161 return 0, err 162 } 163 n := copy(data[o:], e.enc) 164 if n != len(e.enc) { 165 return 0, io.ErrShortBuffer 166 } 167 o += n 168 } 169 return o, nil 170} 171 172func GetRawExtension(m map[int32]Extension, id int32) ([]byte, error) { 173 e := m[id] 174 if err := e.Encode(); err != nil { 175 return nil, err 176 } 177 return e.enc, nil 178} 179 180func size(buf []byte, wire int) (int, error) { 181 switch wire { 182 case WireVarint: 183 _, n := DecodeVarint(buf) 184 return n, nil 185 case WireFixed64: 186 return 8, nil 187 case WireBytes: 188 v, n := DecodeVarint(buf) 189 return int(v) + n, nil 190 case WireFixed32: 191 return 4, nil 192 case WireStartGroup: 193 offset := 0 194 for { 195 u, n := DecodeVarint(buf[offset:]) 196 fwire := int(u & 0x7) 197 offset += n 198 if fwire == WireEndGroup { 199 return offset, nil 200 } 201 s, err := size(buf[offset:], wire) 202 if err != nil { 203 return 0, err 204 } 205 offset += s 206 } 207 } 208 return 0, fmt.Errorf("proto: can't get size for unknown wire type %d", wire) 209} 210 211func BytesToExtensionsMap(buf []byte) (map[int32]Extension, error) { 212 m := make(map[int32]Extension) 213 i := 0 214 for i < len(buf) { 215 tag, n := DecodeVarint(buf[i:]) 216 if n <= 0 { 217 return nil, fmt.Errorf("unable to decode varint") 218 } 219 fieldNum := int32(tag >> 3) 220 wireType := int(tag & 0x7) 221 l, err := size(buf[i+n:], wireType) 222 if err != nil { 223 return nil, err 224 } 225 end := i + int(l) + n 226 m[int32(fieldNum)] = Extension{enc: buf[i:end]} 227 i = end 228 } 229 return m, nil 230} 231 232func NewExtension(e []byte) Extension { 233 ee := Extension{enc: make([]byte, len(e))} 234 copy(ee.enc, e) 235 return ee 236} 237 238func AppendExtension(e Message, tag int32, buf []byte) { 239 if ee, eok := e.(extensionsBytes); eok { 240 ext := ee.GetExtensions() 241 *ext = append(*ext, buf...) 242 return 243 } 244 if ee, eok := e.(extendableProto); eok { 245 m := ee.extensionsWrite() 246 ext := m[int32(tag)] // may be missing 247 ext.enc = append(ext.enc, buf...) 248 m[int32(tag)] = ext 249 } 250} 251 252func encodeExtension(extension *ExtensionDesc, value interface{}) ([]byte, error) { 253 u := getMarshalInfo(reflect.TypeOf(extension.ExtendedType)) 254 ei := u.getExtElemInfo(extension) 255 v := value 256 p := toAddrPointer(&v, ei.isptr) 257 siz := ei.sizer(p, SizeVarint(ei.wiretag)) 258 buf := make([]byte, 0, siz) 259 return ei.marshaler(buf, p, ei.wiretag, false) 260} 261 262func decodeExtensionFromBytes(extension *ExtensionDesc, buf []byte) (interface{}, error) { 263 o := 0 264 for o < len(buf) { 265 tag, n := DecodeVarint((buf)[o:]) 266 fieldNum := int32(tag >> 3) 267 wireType := int(tag & 0x7) 268 if o+n > len(buf) { 269 return nil, fmt.Errorf("unable to decode extension") 270 } 271 l, err := size((buf)[o+n:], wireType) 272 if err != nil { 273 return nil, err 274 } 275 if int32(fieldNum) == extension.Field { 276 if o+n+l > len(buf) { 277 return nil, fmt.Errorf("unable to decode extension") 278 } 279 v, err := decodeExtension((buf)[o:o+n+l], extension) 280 if err != nil { 281 return nil, err 282 } 283 return v, nil 284 } 285 o += n + l 286 } 287 return defaultExtensionValue(extension) 288} 289 290func (this *Extension) Encode() error { 291 if this.enc == nil { 292 var err error 293 this.enc, err = encodeExtension(this.desc, this.value) 294 if err != nil { 295 return err 296 } 297 } 298 return nil 299} 300 301func (this Extension) GoString() string { 302 if err := this.Encode(); err != nil { 303 return fmt.Sprintf("error encoding extension: %v", err) 304 } 305 return fmt.Sprintf("proto.NewExtension(%#v)", this.enc) 306} 307 308func SetUnsafeExtension(pb Message, fieldNum int32, value interface{}) error { 309 typ := reflect.TypeOf(pb).Elem() 310 ext, ok := extensionMaps[typ] 311 if !ok { 312 return fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String()) 313 } 314 desc, ok := ext[fieldNum] 315 if !ok { 316 return errors.New("proto: bad extension number; not in declared ranges") 317 } 318 return SetExtension(pb, desc, value) 319} 320 321func GetUnsafeExtension(pb Message, fieldNum int32) (interface{}, error) { 322 typ := reflect.TypeOf(pb).Elem() 323 ext, ok := extensionMaps[typ] 324 if !ok { 325 return nil, fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String()) 326 } 327 desc, ok := ext[fieldNum] 328 if !ok { 329 return nil, fmt.Errorf("unregistered field number %d", fieldNum) 330 } 331 return GetExtension(pb, desc) 332} 333 334func NewUnsafeXXX_InternalExtensions(m map[int32]Extension) XXX_InternalExtensions { 335 x := &XXX_InternalExtensions{ 336 p: new(struct { 337 mu sync.Mutex 338 extensionMap map[int32]Extension 339 }), 340 } 341 x.p.extensionMap = m 342 return *x 343} 344 345func GetUnsafeExtensionsMap(extendable Message) map[int32]Extension { 346 pb := extendable.(extendableProto) 347 return pb.extensionsWrite() 348} 349 350func deleteExtension(pb extensionsBytes, theFieldNum int32, offset int) int { 351 ext := pb.GetExtensions() 352 for offset < len(*ext) { 353 tag, n1 := DecodeVarint((*ext)[offset:]) 354 fieldNum := int32(tag >> 3) 355 wireType := int(tag & 0x7) 356 n2, err := size((*ext)[offset+n1:], wireType) 357 if err != nil { 358 panic(err) 359 } 360 newOffset := offset + n1 + n2 361 if fieldNum == theFieldNum { 362 *ext = append((*ext)[:offset], (*ext)[newOffset:]...) 363 return offset 364 } 365 offset = newOffset 366 } 367 return -1 368} 369