1// Copyright 2014 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package proto_test 6 7import ( 8 "bytes" 9 "fmt" 10 "reflect" 11 "sort" 12 "strings" 13 "sync" 14 "testing" 15 16 "github.com/golang/protobuf/proto" 17 18 pb2 "github.com/golang/protobuf/internal/testprotos/proto2_proto" 19) 20 21func TestGetExtensionsWithMissingExtensions(t *testing.T) { 22 msg := &pb2.MyMessage{} 23 ext1 := &pb2.Ext{} 24 if err := proto.SetExtension(msg, pb2.E_Ext_More, ext1); err != nil { 25 t.Fatalf("Could not set ext1: %s", err) 26 } 27 exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{ 28 pb2.E_Ext_More, 29 pb2.E_Ext_Text, 30 }) 31 if err != nil { 32 t.Fatalf("GetExtensions() failed: %s", err) 33 } 34 if exts[0] != ext1 { 35 t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0]) 36 } 37 if exts[1] != nil { 38 t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1]) 39 } 40} 41 42func TestGetExtensionForIncompleteDesc(t *testing.T) { 43 msg := &pb2.MyMessage{Count: proto.Int32(0)} 44 extdesc1 := &proto.ExtensionDesc{ 45 ExtendedType: (*pb2.MyMessage)(nil), 46 ExtensionType: (*bool)(nil), 47 Field: 123456789, 48 Name: "a.b", 49 Tag: "varint,123456789,opt", 50 } 51 ext1 := proto.Bool(true) 52 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { 53 t.Fatalf("Could not set ext1: %s", err) 54 } 55 extdesc2 := &proto.ExtensionDesc{ 56 ExtendedType: (*pb2.MyMessage)(nil), 57 ExtensionType: ([]byte)(nil), 58 Field: 123456790, 59 Name: "a.c", 60 Tag: "bytes,123456790,opt", 61 } 62 ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7} 63 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { 64 t.Fatalf("Could not set ext2: %s", err) 65 } 66 extdesc3 := &proto.ExtensionDesc{ 67 ExtendedType: (*pb2.MyMessage)(nil), 68 ExtensionType: (*pb2.Ext)(nil), 69 Field: 123456791, 70 Name: "a.d", 71 Tag: "bytes,123456791,opt", 72 } 73 ext3 := &pb2.Ext{Data: proto.String("foo")} 74 if err := proto.SetExtension(msg, extdesc3, ext3); err != nil { 75 t.Fatalf("Could not set ext3: %s", err) 76 } 77 78 b, err := proto.Marshal(msg) 79 if err != nil { 80 t.Fatalf("Could not marshal msg: %v", err) 81 } 82 if err := proto.Unmarshal(b, msg); err != nil { 83 t.Fatalf("Could not unmarshal into msg: %v", err) 84 } 85 86 var expected proto.Buffer 87 if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil { 88 t.Fatalf("failed to compute expected prefix for ext1: %s", err) 89 } 90 if err := expected.EncodeVarint(1 /* bool true */); err != nil { 91 t.Fatalf("failed to compute expected value for ext1: %s", err) 92 } 93 94 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil { 95 t.Fatalf("Failed to get raw value for ext1: %s", err) 96 } else if !reflect.DeepEqual(b, expected.Bytes()) { 97 t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes()) 98 } 99 100 expected = proto.Buffer{} // reset 101 if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil { 102 t.Fatalf("failed to compute expected prefix for ext2: %s", err) 103 } 104 if err := expected.EncodeRawBytes(ext2); err != nil { 105 t.Fatalf("failed to compute expected value for ext2: %s", err) 106 } 107 108 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil { 109 t.Fatalf("Failed to get raw value for ext2: %s", err) 110 } else if !reflect.DeepEqual(b, expected.Bytes()) { 111 t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes()) 112 } 113 114 expected = proto.Buffer{} // reset 115 if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil { 116 t.Fatalf("failed to compute expected prefix for ext3: %s", err) 117 } 118 if b, err := proto.Marshal(ext3); err != nil { 119 t.Fatalf("failed to compute expected value for ext3: %s", err) 120 } else if err := expected.EncodeRawBytes(b); err != nil { 121 t.Fatalf("failed to compute expected value for ext3: %s", err) 122 } 123 124 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil { 125 t.Fatalf("Failed to get raw value for ext3: %s", err) 126 } else if !reflect.DeepEqual(b, expected.Bytes()) { 127 t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes()) 128 } 129} 130 131func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) { 132 msg := &pb2.MyMessage{Count: proto.Int32(0)} 133 extdesc1 := pb2.E_Ext_More 134 if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil { 135 t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err) 136 } 137 138 ext1 := &pb2.Ext{} 139 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { 140 t.Fatalf("Could not set ext1: %s", err) 141 } 142 extdesc2 := &proto.ExtensionDesc{ 143 ExtendedType: (*pb2.MyMessage)(nil), 144 ExtensionType: (*bool)(nil), 145 Field: 123456789, 146 Name: "a.b", 147 Tag: "varint,123456789,opt", 148 } 149 ext2 := proto.Bool(false) 150 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { 151 t.Fatalf("Could not set ext2: %s", err) 152 } 153 154 b, err := proto.Marshal(msg) 155 if err != nil { 156 t.Fatalf("Could not marshal msg: %v", err) 157 } 158 if err := proto.Unmarshal(b, msg); err != nil { 159 t.Fatalf("Could not unmarshal into msg: %v", err) 160 } 161 162 descs, err := proto.ExtensionDescs(msg) 163 if err != nil { 164 t.Fatalf("proto.ExtensionDescs: got error %v", err) 165 } 166 sortExtDescs(descs) 167 wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}} 168 if !reflect.DeepEqual(descs, wantDescs) { 169 t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs) 170 } 171} 172 173type ExtensionDescSlice []*proto.ExtensionDesc 174 175func (s ExtensionDescSlice) Len() int { return len(s) } 176func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field } 177func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 178 179func sortExtDescs(s []*proto.ExtensionDesc) { 180 sort.Sort(ExtensionDescSlice(s)) 181} 182 183func TestGetExtensionStability(t *testing.T) { 184 check := func(m *pb2.MyMessage) bool { 185 ext1, err := proto.GetExtension(m, pb2.E_Ext_More) 186 if err != nil { 187 t.Fatalf("GetExtension() failed: %s", err) 188 } 189 ext2, err := proto.GetExtension(m, pb2.E_Ext_More) 190 if err != nil { 191 t.Fatalf("GetExtension() failed: %s", err) 192 } 193 return ext1 == ext2 194 } 195 msg := &pb2.MyMessage{Count: proto.Int32(4)} 196 ext0 := &pb2.Ext{} 197 if err := proto.SetExtension(msg, pb2.E_Ext_More, ext0); err != nil { 198 t.Fatalf("Could not set ext1: %s", ext0) 199 } 200 if !check(msg) { 201 t.Errorf("GetExtension() not stable before marshaling") 202 } 203 bb, err := proto.Marshal(msg) 204 if err != nil { 205 t.Fatalf("Marshal() failed: %s", err) 206 } 207 msg1 := &pb2.MyMessage{} 208 err = proto.Unmarshal(bb, msg1) 209 if err != nil { 210 t.Fatalf("Unmarshal() failed: %s", err) 211 } 212 if !check(msg1) { 213 t.Errorf("GetExtension() not stable after unmarshaling") 214 } 215} 216 217func TestGetExtensionDefaults(t *testing.T) { 218 var setFloat64 float64 = 1 219 var setFloat32 float32 = 2 220 var setInt32 int32 = 3 221 var setInt64 int64 = 4 222 var setUint32 uint32 = 5 223 var setUint64 uint64 = 6 224 var setBool = true 225 var setBool2 = false 226 var setString = "Goodnight string" 227 var setBytes = []byte("Goodnight bytes") 228 var setEnum = pb2.DefaultsMessage_TWO 229 230 type testcase struct { 231 ext *proto.ExtensionDesc // Extension we are testing. 232 want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail). 233 def interface{} // Expected value of extension after ClearExtension(). 234 } 235 tests := []testcase{ 236 {pb2.E_NoDefaultDouble, setFloat64, nil}, 237 {pb2.E_NoDefaultFloat, setFloat32, nil}, 238 {pb2.E_NoDefaultInt32, setInt32, nil}, 239 {pb2.E_NoDefaultInt64, setInt64, nil}, 240 {pb2.E_NoDefaultUint32, setUint32, nil}, 241 {pb2.E_NoDefaultUint64, setUint64, nil}, 242 {pb2.E_NoDefaultSint32, setInt32, nil}, 243 {pb2.E_NoDefaultSint64, setInt64, nil}, 244 {pb2.E_NoDefaultFixed32, setUint32, nil}, 245 {pb2.E_NoDefaultFixed64, setUint64, nil}, 246 {pb2.E_NoDefaultSfixed32, setInt32, nil}, 247 {pb2.E_NoDefaultSfixed64, setInt64, nil}, 248 {pb2.E_NoDefaultBool, setBool, nil}, 249 {pb2.E_NoDefaultBool, setBool2, nil}, 250 {pb2.E_NoDefaultString, setString, nil}, 251 {pb2.E_NoDefaultBytes, setBytes, nil}, 252 {pb2.E_NoDefaultEnum, setEnum, nil}, 253 {pb2.E_DefaultDouble, setFloat64, float64(3.1415)}, 254 {pb2.E_DefaultFloat, setFloat32, float32(3.14)}, 255 {pb2.E_DefaultInt32, setInt32, int32(42)}, 256 {pb2.E_DefaultInt64, setInt64, int64(43)}, 257 {pb2.E_DefaultUint32, setUint32, uint32(44)}, 258 {pb2.E_DefaultUint64, setUint64, uint64(45)}, 259 {pb2.E_DefaultSint32, setInt32, int32(46)}, 260 {pb2.E_DefaultSint64, setInt64, int64(47)}, 261 {pb2.E_DefaultFixed32, setUint32, uint32(48)}, 262 {pb2.E_DefaultFixed64, setUint64, uint64(49)}, 263 {pb2.E_DefaultSfixed32, setInt32, int32(50)}, 264 {pb2.E_DefaultSfixed64, setInt64, int64(51)}, 265 {pb2.E_DefaultBool, setBool, true}, 266 {pb2.E_DefaultBool, setBool2, true}, 267 {pb2.E_DefaultString, setString, "Hello, string,def=foo"}, 268 {pb2.E_DefaultBytes, setBytes, []byte("Hello, bytes")}, 269 {pb2.E_DefaultEnum, setEnum, pb2.DefaultsMessage_ONE}, 270 } 271 272 checkVal := func(t *testing.T, name string, test testcase, msg *pb2.DefaultsMessage, valWant interface{}) { 273 t.Run(name, func(t *testing.T) { 274 val, err := proto.GetExtension(msg, test.ext) 275 if err != nil { 276 if valWant != nil { 277 t.Errorf("GetExtension(): %s", err) 278 return 279 } 280 if want := proto.ErrMissingExtension; err != want { 281 t.Errorf("Unexpected error: got %v, want %v", err, want) 282 return 283 } 284 return 285 } 286 287 // All proto2 extension values are either a pointer to a value or a slice of values. 288 ty := reflect.TypeOf(val) 289 tyWant := reflect.TypeOf(test.ext.ExtensionType) 290 if got, want := ty, tyWant; got != want { 291 t.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) 292 return 293 } 294 tye := ty.Elem() 295 tyeWant := tyWant.Elem() 296 if got, want := tye, tyeWant; got != want { 297 t.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) 298 return 299 } 300 301 // Check the name of the type of the value. 302 // If it is an enum it will be type int32 with the name of the enum. 303 if got, want := tye.Name(), tye.Name(); got != want { 304 t.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) 305 return 306 } 307 308 // Check that value is what we expect. 309 // If we have a pointer in val, get the value it points to. 310 valExp := val 311 if ty.Kind() == reflect.Ptr { 312 valExp = reflect.ValueOf(val).Elem().Interface() 313 } 314 if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { 315 t.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) 316 return 317 } 318 }) 319 } 320 321 setTo := func(test testcase) interface{} { 322 setTo := reflect.ValueOf(test.want) 323 if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr { 324 setTo = reflect.New(typ).Elem() 325 setTo.Set(reflect.New(setTo.Type().Elem())) 326 setTo.Elem().Set(reflect.ValueOf(test.want)) 327 } 328 return setTo.Interface() 329 } 330 331 for _, test := range tests { 332 msg := &pb2.DefaultsMessage{} 333 name := test.ext.Name 334 335 // Check the initial value. 336 checkVal(t, name+"/initial", test, msg, test.def) 337 338 // Set the per-type value and check value. 339 if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil { 340 t.Errorf("%s: SetExtension(): %v", name, err) 341 continue 342 } 343 checkVal(t, name+"/set", test, msg, test.want) 344 345 // Set and check the value. 346 proto.ClearExtension(msg, test.ext) 347 checkVal(t, name+"/cleared", test, msg, test.def) 348 } 349} 350 351func TestNilMessage(t *testing.T) { 352 name := "nil interface" 353 if got, err := proto.GetExtension(nil, pb2.E_Ext_More); err == nil { 354 t.Errorf("%s: got %T %v, expected to fail", name, got, got) 355 } else if !strings.Contains(err.Error(), "extendable") { 356 t.Errorf("%s: got error %v, expected not-extendable error", name, err) 357 } 358 359 // Regression tests: all functions of the Extension API 360 // used to panic when passed (*M)(nil), where M is a concrete message 361 // type. Now they handle this gracefully as a no-op or reported error. 362 var nilMsg *pb2.MyMessage 363 desc := pb2.E_Ext_More 364 365 isNotExtendable := func(err error) bool { 366 return strings.Contains(fmt.Sprint(err), "not an extendable") 367 } 368 369 if proto.HasExtension(nilMsg, desc) { 370 t.Error("HasExtension(nil) = true") 371 } 372 373 if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) { 374 t.Errorf("GetExtensions(nil) = %q (wrong error)", err) 375 } 376 377 if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) { 378 t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err) 379 } 380 381 if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) { 382 t.Errorf("SetExtension(nil) = %q (wrong error)", err) 383 } 384 385 proto.ClearExtension(nilMsg, desc) // no-op 386 proto.ClearAllExtensions(nilMsg) // no-op 387} 388 389func TestExtensionsRoundTrip(t *testing.T) { 390 msg := &pb2.MyMessage{} 391 ext1 := &pb2.Ext{ 392 Data: proto.String("hi"), 393 } 394 ext2 := &pb2.Ext{ 395 Data: proto.String("there"), 396 } 397 exists := proto.HasExtension(msg, pb2.E_Ext_More) 398 if exists { 399 t.Error("Extension More present unexpectedly") 400 } 401 if err := proto.SetExtension(msg, pb2.E_Ext_More, ext1); err != nil { 402 t.Error(err) 403 } 404 if err := proto.SetExtension(msg, pb2.E_Ext_More, ext2); err != nil { 405 t.Error(err) 406 } 407 e, err := proto.GetExtension(msg, pb2.E_Ext_More) 408 if err != nil { 409 t.Error(err) 410 } 411 x, ok := e.(*pb2.Ext) 412 if !ok { 413 t.Errorf("e has type %T, expected test_proto.Ext", e) 414 } else if *x.Data != "there" { 415 t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x) 416 } 417 proto.ClearExtension(msg, pb2.E_Ext_More) 418 if _, err = proto.GetExtension(msg, pb2.E_Ext_More); err != proto.ErrMissingExtension { 419 t.Errorf("got %v, expected ErrMissingExtension", e) 420 } 421 if err := proto.SetExtension(msg, pb2.E_Ext_More, 12); err == nil { 422 t.Error("expected some sort of type mismatch error, got nil") 423 } 424} 425 426func TestNilExtension(t *testing.T) { 427 msg := &pb2.MyMessage{ 428 Count: proto.Int32(1), 429 } 430 if err := proto.SetExtension(msg, pb2.E_Ext_Text, proto.String("hello")); err != nil { 431 t.Fatal(err) 432 } 433 if err := proto.SetExtension(msg, pb2.E_Ext_More, (*pb2.Ext)(nil)); err == nil { 434 t.Error("expected SetExtension to fail due to a nil extension") 435 } else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb2.Ext)); err.Error() != want { 436 t.Errorf("expected error %v, got %v", want, err) 437 } 438 // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update 439 // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal. 440} 441 442func TestMarshalUnmarshalRepeatedExtension(t *testing.T) { 443 // Add a repeated extension to the result. 444 tests := []struct { 445 name string 446 ext []*pb2.ComplexExtension 447 }{ 448 { 449 "two fields", 450 []*pb2.ComplexExtension{ 451 {First: proto.Int32(7)}, 452 {Second: proto.Int32(11)}, 453 }, 454 }, 455 { 456 "repeated field", 457 []*pb2.ComplexExtension{ 458 {Third: []int32{1000}}, 459 {Third: []int32{2000}}, 460 }, 461 }, 462 { 463 "two fields and repeated field", 464 []*pb2.ComplexExtension{ 465 {Third: []int32{1000}}, 466 {First: proto.Int32(9)}, 467 {Second: proto.Int32(21)}, 468 {Third: []int32{2000}}, 469 }, 470 }, 471 } 472 for _, test := range tests { 473 // Marshal message with a repeated extension. 474 msg1 := new(pb2.OtherMessage) 475 err := proto.SetExtension(msg1, pb2.E_RComplex, test.ext) 476 if err != nil { 477 t.Fatalf("[%s] Error setting extension: %v", test.name, err) 478 } 479 b, err := proto.Marshal(msg1) 480 if err != nil { 481 t.Fatalf("[%s] Error marshaling message: %v", test.name, err) 482 } 483 484 // Unmarshal and read the merged proto. 485 msg2 := new(pb2.OtherMessage) 486 err = proto.Unmarshal(b, msg2) 487 if err != nil { 488 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) 489 } 490 e, err := proto.GetExtension(msg2, pb2.E_RComplex) 491 if err != nil { 492 t.Fatalf("[%s] Error getting extension: %v", test.name, err) 493 } 494 ext := e.([]*pb2.ComplexExtension) 495 if ext == nil { 496 t.Fatalf("[%s] Invalid extension", test.name) 497 } 498 if len(ext) != len(test.ext) { 499 t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext)) 500 } 501 for i := range test.ext { 502 if !proto.Equal(ext[i], test.ext[i]) { 503 t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i]) 504 } 505 } 506 } 507} 508 509func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) { 510 // We may see multiple instances of the same extension in the wire 511 // format. For example, the proto compiler may encode custom options in 512 // this way. Here, we verify that we merge the extensions together. 513 tests := []struct { 514 name string 515 ext []*pb2.ComplexExtension 516 }{ 517 { 518 "two fields", 519 []*pb2.ComplexExtension{ 520 {First: proto.Int32(7)}, 521 {Second: proto.Int32(11)}, 522 }, 523 }, 524 { 525 "repeated field", 526 []*pb2.ComplexExtension{ 527 {Third: []int32{1000}}, 528 {Third: []int32{2000}}, 529 }, 530 }, 531 { 532 "two fields and repeated field", 533 []*pb2.ComplexExtension{ 534 {Third: []int32{1000}}, 535 {First: proto.Int32(9)}, 536 {Second: proto.Int32(21)}, 537 {Third: []int32{2000}}, 538 }, 539 }, 540 } 541 for _, test := range tests { 542 var buf bytes.Buffer 543 var want pb2.ComplexExtension 544 545 // Generate a serialized representation of a repeated extension 546 // by catenating bytes together. 547 for i, e := range test.ext { 548 // Merge to create the wanted proto. 549 proto.Merge(&want, e) 550 551 // serialize the message 552 msg := new(pb2.OtherMessage) 553 err := proto.SetExtension(msg, pb2.E_Complex, e) 554 if err != nil { 555 t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err) 556 } 557 b, err := proto.Marshal(msg) 558 if err != nil { 559 t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err) 560 } 561 buf.Write(b) 562 } 563 564 // Unmarshal and read the merged proto. 565 msg2 := new(pb2.OtherMessage) 566 err := proto.Unmarshal(buf.Bytes(), msg2) 567 if err != nil { 568 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) 569 } 570 e, err := proto.GetExtension(msg2, pb2.E_Complex) 571 if err != nil { 572 t.Fatalf("[%s] Error getting extension: %v", test.name, err) 573 } 574 ext := e.(*pb2.ComplexExtension) 575 if ext == nil { 576 t.Fatalf("[%s] Invalid extension", test.name) 577 } 578 if !proto.Equal(ext, &want) { 579 t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want) 580 } 581 } 582} 583 584func TestClearAllExtensions(t *testing.T) { 585 // unregistered extension 586 desc := &proto.ExtensionDesc{ 587 ExtendedType: (*pb2.MyMessage)(nil), 588 ExtensionType: (*bool)(nil), 589 Field: 101010100, 590 Name: "emptyextension", 591 Tag: "varint,0,opt", 592 } 593 m := &pb2.MyMessage{} 594 if proto.HasExtension(m, desc) { 595 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) 596 } 597 if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil { 598 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err) 599 } 600 if !proto.HasExtension(m, desc) { 601 t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m)) 602 } 603 proto.ClearAllExtensions(m) 604 if proto.HasExtension(m, desc) { 605 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) 606 } 607} 608 609func TestMarshalRace(t *testing.T) { 610 ext := &pb2.Ext{} 611 m := &pb2.MyMessage{Count: proto.Int32(4)} 612 if err := proto.SetExtension(m, pb2.E_Ext_More, ext); err != nil { 613 t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err) 614 } 615 616 b, err := proto.Marshal(m) 617 if err != nil { 618 t.Fatalf("Could not marshal message: %v", err) 619 } 620 if err := proto.Unmarshal(b, m); err != nil { 621 t.Fatalf("Could not unmarshal message: %v", err) 622 } 623 // after Unmarshal, the extension is in undecoded form. 624 // GetExtension will decode it lazily. Make sure this does 625 // not race against Marshal. 626 627 wg := sync.WaitGroup{} 628 errs := make(chan error, 3) 629 for n := 3; n > 0; n-- { 630 wg.Add(1) 631 go func() { 632 defer wg.Done() 633 _, err := proto.Marshal(m) 634 errs <- err 635 }() 636 } 637 wg.Wait() 638 close(errs) 639 640 for err = range errs { 641 if err != nil { 642 t.Fatal(err) 643 } 644 } 645} 646