// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 32package proto_test 33 34import ( 35 "bytes" 36 "fmt" 37 "io" 38 "reflect" 39 "sort" 40 "strings" 41 "sync" 42 "testing" 43 44 "github.com/golang/protobuf/proto" 45 pb "github.com/golang/protobuf/proto/test_proto" 46) 47 48func TestGetExtensionsWithMissingExtensions(t *testing.T) { 49 msg := &pb.MyMessage{} 50 ext1 := &pb.Ext{} 51 if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { 52 t.Fatalf("Could not set ext1: %s", err) 53 } 54 exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{ 55 pb.E_Ext_More, 56 pb.E_Ext_Text, 57 }) 58 if err != nil { 59 t.Fatalf("GetExtensions() failed: %s", err) 60 } 61 if exts[0] != ext1 { 62 t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0]) 63 } 64 if exts[1] != nil { 65 t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1]) 66 } 67} 68 69func TestGetExtensionWithEmptyBuffer(t *testing.T) { 70 // Make sure that GetExtension returns an error if its 71 // undecoded buffer is empty. 72 msg := &pb.MyMessage{} 73 proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{}) 74 _, err := proto.GetExtension(msg, pb.E_Ext_More) 75 if want := io.ErrUnexpectedEOF; err != want { 76 t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want) 77 } 78} 79 80func TestGetExtensionForIncompleteDesc(t *testing.T) { 81 msg := &pb.MyMessage{Count: proto.Int32(0)} 82 extdesc1 := &proto.ExtensionDesc{ 83 ExtendedType: (*pb.MyMessage)(nil), 84 ExtensionType: (*bool)(nil), 85 Field: 123456789, 86 Name: "a.b", 87 Tag: "varint,123456789,opt", 88 } 89 ext1 := proto.Bool(true) 90 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { 91 t.Fatalf("Could not set ext1: %s", err) 92 } 93 extdesc2 := &proto.ExtensionDesc{ 94 ExtendedType: (*pb.MyMessage)(nil), 95 ExtensionType: ([]byte)(nil), 96 Field: 123456790, 97 Name: "a.c", 98 Tag: "bytes,123456790,opt", 99 } 100 ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7} 101 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { 
102 t.Fatalf("Could not set ext2: %s", err) 103 } 104 extdesc3 := &proto.ExtensionDesc{ 105 ExtendedType: (*pb.MyMessage)(nil), 106 ExtensionType: (*pb.Ext)(nil), 107 Field: 123456791, 108 Name: "a.d", 109 Tag: "bytes,123456791,opt", 110 } 111 ext3 := &pb.Ext{Data: proto.String("foo")} 112 if err := proto.SetExtension(msg, extdesc3, ext3); err != nil { 113 t.Fatalf("Could not set ext3: %s", err) 114 } 115 116 b, err := proto.Marshal(msg) 117 if err != nil { 118 t.Fatalf("Could not marshal msg: %v", err) 119 } 120 if err := proto.Unmarshal(b, msg); err != nil { 121 t.Fatalf("Could not unmarshal into msg: %v", err) 122 } 123 124 var expected proto.Buffer 125 if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil { 126 t.Fatalf("failed to compute expected prefix for ext1: %s", err) 127 } 128 if err := expected.EncodeVarint(1 /* bool true */); err != nil { 129 t.Fatalf("failed to compute expected value for ext1: %s", err) 130 } 131 132 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil { 133 t.Fatalf("Failed to get raw value for ext1: %s", err) 134 } else if !reflect.DeepEqual(b, expected.Bytes()) { 135 t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes()) 136 } 137 138 expected = proto.Buffer{} // reset 139 if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil { 140 t.Fatalf("failed to compute expected prefix for ext2: %s", err) 141 } 142 if err := expected.EncodeRawBytes(ext2); err != nil { 143 t.Fatalf("failed to compute expected value for ext2: %s", err) 144 } 145 146 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil { 147 t.Fatalf("Failed to get raw value for ext2: %s", err) 148 } else if !reflect.DeepEqual(b, expected.Bytes()) { 149 t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes()) 150 } 151 152 expected = proto.Buffer{} // reset 153 if err := 
expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil { 154 t.Fatalf("failed to compute expected prefix for ext3: %s", err) 155 } 156 if b, err := proto.Marshal(ext3); err != nil { 157 t.Fatalf("failed to compute expected value for ext3: %s", err) 158 } else if err := expected.EncodeRawBytes(b); err != nil { 159 t.Fatalf("failed to compute expected value for ext3: %s", err) 160 } 161 162 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil { 163 t.Fatalf("Failed to get raw value for ext3: %s", err) 164 } else if !reflect.DeepEqual(b, expected.Bytes()) { 165 t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes()) 166 } 167} 168 169func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) { 170 msg := &pb.MyMessage{Count: proto.Int32(0)} 171 extdesc1 := pb.E_Ext_More 172 if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil { 173 t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err) 174 } 175 176 ext1 := &pb.Ext{} 177 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { 178 t.Fatalf("Could not set ext1: %s", err) 179 } 180 extdesc2 := &proto.ExtensionDesc{ 181 ExtendedType: (*pb.MyMessage)(nil), 182 ExtensionType: (*bool)(nil), 183 Field: 123456789, 184 Name: "a.b", 185 Tag: "varint,123456789,opt", 186 } 187 ext2 := proto.Bool(false) 188 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { 189 t.Fatalf("Could not set ext2: %s", err) 190 } 191 192 b, err := proto.Marshal(msg) 193 if err != nil { 194 t.Fatalf("Could not marshal msg: %v", err) 195 } 196 if err := proto.Unmarshal(b, msg); err != nil { 197 t.Fatalf("Could not unmarshal into msg: %v", err) 198 } 199 200 descs, err := proto.ExtensionDescs(msg) 201 if err != nil { 202 t.Fatalf("proto.ExtensionDescs: got error %v", err) 203 } 204 sortExtDescs(descs) 205 wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}} 206 if 
!reflect.DeepEqual(descs, wantDescs) { 207 t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs) 208 } 209} 210 211type ExtensionDescSlice []*proto.ExtensionDesc 212 213func (s ExtensionDescSlice) Len() int { return len(s) } 214func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field } 215func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 216 217func sortExtDescs(s []*proto.ExtensionDesc) { 218 sort.Sort(ExtensionDescSlice(s)) 219} 220 221func TestGetExtensionStability(t *testing.T) { 222 check := func(m *pb.MyMessage) bool { 223 ext1, err := proto.GetExtension(m, pb.E_Ext_More) 224 if err != nil { 225 t.Fatalf("GetExtension() failed: %s", err) 226 } 227 ext2, err := proto.GetExtension(m, pb.E_Ext_More) 228 if err != nil { 229 t.Fatalf("GetExtension() failed: %s", err) 230 } 231 return ext1 == ext2 232 } 233 msg := &pb.MyMessage{Count: proto.Int32(4)} 234 ext0 := &pb.Ext{} 235 if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil { 236 t.Fatalf("Could not set ext1: %s", ext0) 237 } 238 if !check(msg) { 239 t.Errorf("GetExtension() not stable before marshaling") 240 } 241 bb, err := proto.Marshal(msg) 242 if err != nil { 243 t.Fatalf("Marshal() failed: %s", err) 244 } 245 msg1 := &pb.MyMessage{} 246 err = proto.Unmarshal(bb, msg1) 247 if err != nil { 248 t.Fatalf("Unmarshal() failed: %s", err) 249 } 250 if !check(msg1) { 251 t.Errorf("GetExtension() not stable after unmarshaling") 252 } 253} 254 255func TestGetExtensionDefaults(t *testing.T) { 256 var setFloat64 float64 = 1 257 var setFloat32 float32 = 2 258 var setInt32 int32 = 3 259 var setInt64 int64 = 4 260 var setUint32 uint32 = 5 261 var setUint64 uint64 = 6 262 var setBool = true 263 var setBool2 = false 264 var setString = "Goodnight string" 265 var setBytes = []byte("Goodnight bytes") 266 var setEnum = pb.DefaultsMessage_TWO 267 268 type testcase struct { 269 ext *proto.ExtensionDesc // Extension we are 
testing. 270 want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail). 271 def interface{} // Expected value of extension after ClearExtension(). 272 } 273 tests := []testcase{ 274 {pb.E_NoDefaultDouble, setFloat64, nil}, 275 {pb.E_NoDefaultFloat, setFloat32, nil}, 276 {pb.E_NoDefaultInt32, setInt32, nil}, 277 {pb.E_NoDefaultInt64, setInt64, nil}, 278 {pb.E_NoDefaultUint32, setUint32, nil}, 279 {pb.E_NoDefaultUint64, setUint64, nil}, 280 {pb.E_NoDefaultSint32, setInt32, nil}, 281 {pb.E_NoDefaultSint64, setInt64, nil}, 282 {pb.E_NoDefaultFixed32, setUint32, nil}, 283 {pb.E_NoDefaultFixed64, setUint64, nil}, 284 {pb.E_NoDefaultSfixed32, setInt32, nil}, 285 {pb.E_NoDefaultSfixed64, setInt64, nil}, 286 {pb.E_NoDefaultBool, setBool, nil}, 287 {pb.E_NoDefaultBool, setBool2, nil}, 288 {pb.E_NoDefaultString, setString, nil}, 289 {pb.E_NoDefaultBytes, setBytes, nil}, 290 {pb.E_NoDefaultEnum, setEnum, nil}, 291 {pb.E_DefaultDouble, setFloat64, float64(3.1415)}, 292 {pb.E_DefaultFloat, setFloat32, float32(3.14)}, 293 {pb.E_DefaultInt32, setInt32, int32(42)}, 294 {pb.E_DefaultInt64, setInt64, int64(43)}, 295 {pb.E_DefaultUint32, setUint32, uint32(44)}, 296 {pb.E_DefaultUint64, setUint64, uint64(45)}, 297 {pb.E_DefaultSint32, setInt32, int32(46)}, 298 {pb.E_DefaultSint64, setInt64, int64(47)}, 299 {pb.E_DefaultFixed32, setUint32, uint32(48)}, 300 {pb.E_DefaultFixed64, setUint64, uint64(49)}, 301 {pb.E_DefaultSfixed32, setInt32, int32(50)}, 302 {pb.E_DefaultSfixed64, setInt64, int64(51)}, 303 {pb.E_DefaultBool, setBool, true}, 304 {pb.E_DefaultBool, setBool2, true}, 305 {pb.E_DefaultString, setString, "Hello, string,def=foo"}, 306 {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")}, 307 {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE}, 308 } 309 310 checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error { 311 val, err := proto.GetExtension(msg, test.ext) 312 if err != nil { 313 if valWant != nil { 314 return 
fmt.Errorf("GetExtension(): %s", err) 315 } 316 if want := proto.ErrMissingExtension; err != want { 317 return fmt.Errorf("Unexpected error: got %v, want %v", err, want) 318 } 319 return nil 320 } 321 322 // All proto2 extension values are either a pointer to a value or a slice of values. 323 ty := reflect.TypeOf(val) 324 tyWant := reflect.TypeOf(test.ext.ExtensionType) 325 if got, want := ty, tyWant; got != want { 326 return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) 327 } 328 tye := ty.Elem() 329 tyeWant := tyWant.Elem() 330 if got, want := tye, tyeWant; got != want { 331 return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) 332 } 333 334 // Check the name of the type of the value. 335 // If it is an enum it will be type int32 with the name of the enum. 336 if got, want := tye.Name(), tye.Name(); got != want { 337 return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) 338 } 339 340 // Check that value is what we expect. 341 // If we have a pointer in val, get the value it points to. 342 valExp := val 343 if ty.Kind() == reflect.Ptr { 344 valExp = reflect.ValueOf(val).Elem().Interface() 345 } 346 if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { 347 return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) 348 } 349 350 return nil 351 } 352 353 setTo := func(test testcase) interface{} { 354 setTo := reflect.ValueOf(test.want) 355 if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr { 356 setTo = reflect.New(typ).Elem() 357 setTo.Set(reflect.New(setTo.Type().Elem())) 358 setTo.Elem().Set(reflect.ValueOf(test.want)) 359 } 360 return setTo.Interface() 361 } 362 363 for _, test := range tests { 364 msg := &pb.DefaultsMessage{} 365 name := test.ext.Name 366 367 // Check the initial value. 
368 if err := checkVal(test, msg, test.def); err != nil { 369 t.Errorf("%s: %v", name, err) 370 } 371 372 // Set the per-type value and check value. 373 name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want) 374 if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil { 375 t.Errorf("%s: SetExtension(): %v", name, err) 376 continue 377 } 378 if err := checkVal(test, msg, test.want); err != nil { 379 t.Errorf("%s: %v", name, err) 380 continue 381 } 382 383 // Set and check the value. 384 name += " (cleared)" 385 proto.ClearExtension(msg, test.ext) 386 if err := checkVal(test, msg, test.def); err != nil { 387 t.Errorf("%s: %v", name, err) 388 } 389 } 390} 391 392func TestNilMessage(t *testing.T) { 393 name := "nil interface" 394 if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil { 395 t.Errorf("%s: got %T %v, expected to fail", name, got, got) 396 } else if !strings.Contains(err.Error(), "extendable") { 397 t.Errorf("%s: got error %v, expected not-extendable error", name, err) 398 } 399 400 // Regression tests: all functions of the Extension API 401 // used to panic when passed (*M)(nil), where M is a concrete message 402 // type. Now they handle this gracefully as a no-op or reported error. 
403 var nilMsg *pb.MyMessage 404 desc := pb.E_Ext_More 405 406 isNotExtendable := func(err error) bool { 407 return strings.Contains(fmt.Sprint(err), "not extendable") 408 } 409 410 if proto.HasExtension(nilMsg, desc) { 411 t.Error("HasExtension(nil) = true") 412 } 413 414 if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) { 415 t.Errorf("GetExtensions(nil) = %q (wrong error)", err) 416 } 417 418 if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) { 419 t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err) 420 } 421 422 if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) { 423 t.Errorf("SetExtension(nil) = %q (wrong error)", err) 424 } 425 426 proto.ClearExtension(nilMsg, desc) // no-op 427 proto.ClearAllExtensions(nilMsg) // no-op 428} 429 430func TestExtensionsRoundTrip(t *testing.T) { 431 msg := &pb.MyMessage{} 432 ext1 := &pb.Ext{ 433 Data: proto.String("hi"), 434 } 435 ext2 := &pb.Ext{ 436 Data: proto.String("there"), 437 } 438 exists := proto.HasExtension(msg, pb.E_Ext_More) 439 if exists { 440 t.Error("Extension More present unexpectedly") 441 } 442 if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { 443 t.Error(err) 444 } 445 if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil { 446 t.Error(err) 447 } 448 e, err := proto.GetExtension(msg, pb.E_Ext_More) 449 if err != nil { 450 t.Error(err) 451 } 452 x, ok := e.(*pb.Ext) 453 if !ok { 454 t.Errorf("e has type %T, expected test_proto.Ext", e) 455 } else if *x.Data != "there" { 456 t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x) 457 } 458 proto.ClearExtension(msg, pb.E_Ext_More) 459 if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension { 460 t.Errorf("got %v, expected ErrMissingExtension", e) 461 } 462 if _, err := proto.GetExtension(msg, pb.E_X215); err == nil { 463 t.Error("expected bad extension error, got nil") 464 } 465 if err := 
proto.SetExtension(msg, pb.E_X215, 12); err == nil { 466 t.Error("expected extension err") 467 } 468 if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil { 469 t.Error("expected some sort of type mismatch error, got nil") 470 } 471} 472 473func TestNilExtension(t *testing.T) { 474 msg := &pb.MyMessage{ 475 Count: proto.Int32(1), 476 } 477 if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil { 478 t.Fatal(err) 479 } 480 if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil { 481 t.Error("expected SetExtension to fail due to a nil extension") 482 } else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want { 483 t.Errorf("expected error %v, got %v", want, err) 484 } 485 // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update 486 // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal. 487} 488 489func TestMarshalUnmarshalRepeatedExtension(t *testing.T) { 490 // Add a repeated extension to the result. 491 tests := []struct { 492 name string 493 ext []*pb.ComplexExtension 494 }{ 495 { 496 "two fields", 497 []*pb.ComplexExtension{ 498 {First: proto.Int32(7)}, 499 {Second: proto.Int32(11)}, 500 }, 501 }, 502 { 503 "repeated field", 504 []*pb.ComplexExtension{ 505 {Third: []int32{1000}}, 506 {Third: []int32{2000}}, 507 }, 508 }, 509 { 510 "two fields and repeated field", 511 []*pb.ComplexExtension{ 512 {Third: []int32{1000}}, 513 {First: proto.Int32(9)}, 514 {Second: proto.Int32(21)}, 515 {Third: []int32{2000}}, 516 }, 517 }, 518 } 519 for _, test := range tests { 520 // Marshal message with a repeated extension. 
521 msg1 := new(pb.OtherMessage) 522 err := proto.SetExtension(msg1, pb.E_RComplex, test.ext) 523 if err != nil { 524 t.Fatalf("[%s] Error setting extension: %v", test.name, err) 525 } 526 b, err := proto.Marshal(msg1) 527 if err != nil { 528 t.Fatalf("[%s] Error marshaling message: %v", test.name, err) 529 } 530 531 // Unmarshal and read the merged proto. 532 msg2 := new(pb.OtherMessage) 533 err = proto.Unmarshal(b, msg2) 534 if err != nil { 535 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) 536 } 537 e, err := proto.GetExtension(msg2, pb.E_RComplex) 538 if err != nil { 539 t.Fatalf("[%s] Error getting extension: %v", test.name, err) 540 } 541 ext := e.([]*pb.ComplexExtension) 542 if ext == nil { 543 t.Fatalf("[%s] Invalid extension", test.name) 544 } 545 if len(ext) != len(test.ext) { 546 t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext)) 547 } 548 for i := range test.ext { 549 if !proto.Equal(ext[i], test.ext[i]) { 550 t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i]) 551 } 552 } 553 } 554} 555 556func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) { 557 // We may see multiple instances of the same extension in the wire 558 // format. For example, the proto compiler may encode custom options in 559 // this way. Here, we verify that we merge the extensions together. 
560 tests := []struct { 561 name string 562 ext []*pb.ComplexExtension 563 }{ 564 { 565 "two fields", 566 []*pb.ComplexExtension{ 567 {First: proto.Int32(7)}, 568 {Second: proto.Int32(11)}, 569 }, 570 }, 571 { 572 "repeated field", 573 []*pb.ComplexExtension{ 574 {Third: []int32{1000}}, 575 {Third: []int32{2000}}, 576 }, 577 }, 578 { 579 "two fields and repeated field", 580 []*pb.ComplexExtension{ 581 {Third: []int32{1000}}, 582 {First: proto.Int32(9)}, 583 {Second: proto.Int32(21)}, 584 {Third: []int32{2000}}, 585 }, 586 }, 587 } 588 for _, test := range tests { 589 var buf bytes.Buffer 590 var want pb.ComplexExtension 591 592 // Generate a serialized representation of a repeated extension 593 // by catenating bytes together. 594 for i, e := range test.ext { 595 // Merge to create the wanted proto. 596 proto.Merge(&want, e) 597 598 // serialize the message 599 msg := new(pb.OtherMessage) 600 err := proto.SetExtension(msg, pb.E_Complex, e) 601 if err != nil { 602 t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err) 603 } 604 b, err := proto.Marshal(msg) 605 if err != nil { 606 t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err) 607 } 608 buf.Write(b) 609 } 610 611 // Unmarshal and read the merged proto. 
612 msg2 := new(pb.OtherMessage) 613 err := proto.Unmarshal(buf.Bytes(), msg2) 614 if err != nil { 615 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) 616 } 617 e, err := proto.GetExtension(msg2, pb.E_Complex) 618 if err != nil { 619 t.Fatalf("[%s] Error getting extension: %v", test.name, err) 620 } 621 ext := e.(*pb.ComplexExtension) 622 if ext == nil { 623 t.Fatalf("[%s] Invalid extension", test.name) 624 } 625 if !proto.Equal(ext, &want) { 626 t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want) 627 } 628 } 629} 630 631func TestClearAllExtensions(t *testing.T) { 632 // unregistered extension 633 desc := &proto.ExtensionDesc{ 634 ExtendedType: (*pb.MyMessage)(nil), 635 ExtensionType: (*bool)(nil), 636 Field: 101010100, 637 Name: "emptyextension", 638 Tag: "varint,0,opt", 639 } 640 m := &pb.MyMessage{} 641 if proto.HasExtension(m, desc) { 642 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) 643 } 644 if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil { 645 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err) 646 } 647 if !proto.HasExtension(m, desc) { 648 t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m)) 649 } 650 proto.ClearAllExtensions(m) 651 if proto.HasExtension(m, desc) { 652 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) 653 } 654} 655 656func TestMarshalRace(t *testing.T) { 657 ext := &pb.Ext{} 658 m := &pb.MyMessage{Count: proto.Int32(4)} 659 if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil { 660 t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err) 661 } 662 663 b, err := proto.Marshal(m) 664 if err != nil { 665 t.Fatalf("Could not marshal message: %v", err) 666 } 667 if err := proto.Unmarshal(b, m); err != nil { 668 t.Fatalf("Could not unmarshal message: %v", err) 669 } 670 // after Unmarshal, the extension is in 
undecoded form. 671 // GetExtension will decode it lazily. Make sure this does 672 // not race against Marshal. 673 674 wg := sync.WaitGroup{} 675 errs := make(chan error, 3) 676 for n := 3; n > 0; n-- { 677 wg.Add(1) 678 go func() { 679 defer wg.Done() 680 _, err := proto.Marshal(m) 681 errs <- err 682 }() 683 } 684 wg.Wait() 685 close(errs) 686 687 for err = range errs { 688 if err != nil { 689 t.Fatal(err) 690 } 691 } 692} 693