1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto_test
6
7import (
8	"bytes"
9	"fmt"
10	"reflect"
11	"sort"
12	"strings"
13	"sync"
14	"testing"
15
16	"github.com/golang/protobuf/proto"
17
18	pb2 "github.com/golang/protobuf/internal/testprotos/proto2_proto"
19)
20
// TestGetExtensionsWithMissingExtensions checks that GetExtensions returns the
// stored value for an extension that was set and nil for one that was not.
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
	m := &pb2.MyMessage{}
	more := &pb2.Ext{}
	if err := proto.SetExtension(m, pb2.E_Ext_More, more); err != nil {
		t.Fatalf("Could not set ext1: %s", err)
	}

	// Ask for one present and one absent extension in a single call.
	descs := []*proto.ExtensionDesc{pb2.E_Ext_More, pb2.E_Ext_Text}
	got, err := proto.GetExtensions(m, descs)
	if err != nil {
		t.Fatalf("GetExtensions() failed: %s", err)
	}

	if first := got[0]; first != more {
		t.Errorf("ext1 not in returned extensions: %T %v", first, first)
	}
	if second := got[1]; second != nil {
		t.Errorf("ext2 in returned extensions: %T %v", second, second)
	}
}
41
// TestGetExtensionForIncompleteDesc sets three extensions through complete
// descriptors and round-trips the message through Marshal/Unmarshal. It then
// reads each extension back with a descriptor that has only Field set and
// expects GetExtension to return the raw encoding: the key varint followed by
// the encoded payload.
func TestGetExtensionForIncompleteDesc(t *testing.T) {
	msg := &pb2.MyMessage{Count: proto.Int32(0)}

	// ext1: an optional bool extension.
	boolDesc := &proto.ExtensionDesc{
		ExtendedType:  (*pb2.MyMessage)(nil),
		ExtensionType: (*bool)(nil),
		Field:         123456789,
		Name:          "a.b",
		Tag:           "varint,123456789,opt",
	}
	if err := proto.SetExtension(msg, boolDesc, proto.Bool(true)); err != nil {
		t.Fatalf("Could not set ext1: %s", err)
	}

	// ext2: an optional bytes extension.
	bytesDesc := &proto.ExtensionDesc{
		ExtendedType:  (*pb2.MyMessage)(nil),
		ExtensionType: ([]byte)(nil),
		Field:         123456790,
		Name:          "a.c",
		Tag:           "bytes,123456790,opt",
	}
	bytesVal := []byte{0, 1, 2, 3, 4, 5, 6, 7}
	if err := proto.SetExtension(msg, bytesDesc, bytesVal); err != nil {
		t.Fatalf("Could not set ext2: %s", err)
	}

	// ext3: an optional embedded-message extension.
	msgDesc := &proto.ExtensionDesc{
		ExtendedType:  (*pb2.MyMessage)(nil),
		ExtensionType: (*pb2.Ext)(nil),
		Field:         123456791,
		Name:          "a.d",
		Tag:           "bytes,123456791,opt",
	}
	msgVal := &pb2.Ext{Data: proto.String("foo")}
	if err := proto.SetExtension(msg, msgDesc, msgVal); err != nil {
		t.Fatalf("Could not set ext3: %s", err)
	}

	// Round-trip through the wire format before reading anything back.
	wire, err := proto.Marshal(msg)
	if err != nil {
		t.Fatalf("Could not marshal msg: %v", err)
	}
	if err := proto.Unmarshal(wire, msg); err != nil {
		t.Fatalf("Could not unmarshal into msg: %v", err)
	}

	// ext1: key (field<<3 | varint wire type) followed by the varint 1.
	var want1 proto.Buffer
	if err := want1.EncodeVarint(uint64((boolDesc.Field << 3) | proto.WireVarint)); err != nil {
		t.Fatalf("failed to compute expected prefix for ext1: %s", err)
	}
	if err := want1.EncodeVarint(1 /* bool true */); err != nil {
		t.Fatalf("failed to compute expected value for ext1: %s", err)
	}
	raw, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: boolDesc.Field})
	if err != nil {
		t.Fatalf("Failed to get raw value for ext1: %s", err)
	}
	if !reflect.DeepEqual(raw, want1.Bytes()) {
		t.Fatalf("Raw value for ext1: got %v, want %v", raw, want1.Bytes())
	}

	// ext2: key (field<<3 | bytes wire type) followed by EncodeRawBytes(value).
	var want2 proto.Buffer
	if err := want2.EncodeVarint(uint64((bytesDesc.Field << 3) | proto.WireBytes)); err != nil {
		t.Fatalf("failed to compute expected prefix for ext2: %s", err)
	}
	if err := want2.EncodeRawBytes(bytesVal); err != nil {
		t.Fatalf("failed to compute expected value for ext2: %s", err)
	}
	if raw, err = proto.GetExtension(msg, &proto.ExtensionDesc{Field: bytesDesc.Field}); err != nil {
		t.Fatalf("Failed to get raw value for ext2: %s", err)
	}
	if !reflect.DeepEqual(raw, want2.Bytes()) {
		t.Fatalf("Raw value for ext2: got %v, want %v", raw, want2.Bytes())
	}

	// ext3: key (field<<3 | bytes wire type) followed by
	// EncodeRawBytes(Marshal(value)).
	var want3 proto.Buffer
	if err := want3.EncodeVarint(uint64((msgDesc.Field << 3) | proto.WireBytes)); err != nil {
		t.Fatalf("failed to compute expected prefix for ext3: %s", err)
	}
	encoded, err := proto.Marshal(msgVal)
	if err != nil {
		t.Fatalf("failed to compute expected value for ext3: %s", err)
	}
	if err := want3.EncodeRawBytes(encoded); err != nil {
		t.Fatalf("failed to compute expected value for ext3: %s", err)
	}
	if raw, err = proto.GetExtension(msg, &proto.ExtensionDesc{Field: msgDesc.Field}); err != nil {
		t.Fatalf("Failed to get raw value for ext3: %s", err)
	}
	if !reflect.DeepEqual(raw, want3.Bytes()) {
		t.Fatalf("Raw value for ext3: got %v, want %v", raw, want3.Bytes())
	}
}
130
// TestExtensionDescsWithUnregisteredExtensions checks ExtensionDescs after a
// Marshal/Unmarshal round trip. The extension set through pb2.E_Ext_More is
// reported as that descriptor. The one set through a locally built
// descriptor is reported as a descriptor carrying only its Field number.
func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
	msg := &pb2.MyMessage{Count: proto.Int32(0)}
	registered := pb2.E_Ext_More

	// A fresh message reports no extensions.
	descs, err := proto.ExtensionDescs(msg)
	if len(descs) != 0 || err != nil {
		t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
	}

	if err := proto.SetExtension(msg, registered, &pb2.Ext{}); err != nil {
		t.Fatalf("Could not set ext1: %s", err)
	}

	unregistered := &proto.ExtensionDesc{
		ExtendedType:  (*pb2.MyMessage)(nil),
		ExtensionType: (*bool)(nil),
		Field:         123456789,
		Name:          "a.b",
		Tag:           "varint,123456789,opt",
	}
	if err := proto.SetExtension(msg, unregistered, proto.Bool(false)); err != nil {
		t.Fatalf("Could not set ext2: %s", err)
	}

	wire, err := proto.Marshal(msg)
	if err != nil {
		t.Fatalf("Could not marshal msg: %v", err)
	}
	if err := proto.Unmarshal(wire, msg); err != nil {
		t.Fatalf("Could not unmarshal into msg: %v", err)
	}

	descs, err = proto.ExtensionDescs(msg)
	if err != nil {
		t.Fatalf("proto.ExtensionDescs: got error %v", err)
	}
	sortExtDescs(descs)

	want := []*proto.ExtensionDesc{registered, {Field: unregistered.Field}}
	if !reflect.DeepEqual(descs, want) {
		t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, want)
	}
}
172
173type ExtensionDescSlice []*proto.ExtensionDesc
174
175func (s ExtensionDescSlice) Len() int           { return len(s) }
176func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
177func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
178
179func sortExtDescs(s []*proto.ExtensionDesc) {
180	sort.Sort(ExtensionDescSlice(s))
181}
182
// TestGetExtensionStability checks that two consecutive GetExtension calls
// return the same (==) value. It checks both a message populated in memory
// and a message produced by Unmarshal.
func TestGetExtensionStability(t *testing.T) {
	// check reports whether two back-to-back GetExtension calls for
	// E_Ext_More on m return equal interface values.
	check := func(m *pb2.MyMessage) bool {
		ext1, err := proto.GetExtension(m, pb2.E_Ext_More)
		if err != nil {
			t.Fatalf("GetExtension() failed: %s", err)
		}
		ext2, err := proto.GetExtension(m, pb2.E_Ext_More)
		if err != nil {
			t.Fatalf("GetExtension() failed: %s", err)
		}
		return ext1 == ext2
	}

	msg := &pb2.MyMessage{Count: proto.Int32(4)}
	ext0 := &pb2.Ext{}
	if err := proto.SetExtension(msg, pb2.E_Ext_More, ext0); err != nil {
		// Report the error itself (the value being set carries no diagnostic).
		t.Fatalf("Could not set ext1: %s", err)
	}
	if !check(msg) {
		t.Errorf("GetExtension() not stable before marshaling")
	}

	bb, err := proto.Marshal(msg)
	if err != nil {
		t.Fatalf("Marshal() failed: %s", err)
	}
	msg1 := &pb2.MyMessage{}
	if err := proto.Unmarshal(bb, msg1); err != nil {
		t.Fatalf("Unmarshal() failed: %s", err)
	}
	if !check(msg1) {
		t.Errorf("GetExtension() not stable after unmarshaling")
	}
}
216
// TestGetExtensionDefaults checks, for scalar extensions with and without a
// declared default, the value GetExtension reports in three states: on a fresh
// message, after SetExtension, and after ClearExtension.
func TestGetExtensionDefaults(t *testing.T) {
	var setFloat64 float64 = 1
	var setFloat32 float32 = 2
	var setInt32 int32 = 3
	var setInt64 int64 = 4
	var setUint32 uint32 = 5
	var setUint64 uint64 = 6
	var setBool = true
	var setBool2 = false
	var setString = "Goodnight string"
	var setBytes = []byte("Goodnight bytes")
	var setEnum = pb2.DefaultsMessage_TWO

	type testcase struct {
		ext  *proto.ExtensionDesc // Extension we are testing.
		want interface{}          // Value passed to SetExtension and expected back afterwards.
		def  interface{}          // Expected value before SetExtension and after ClearExtension(); nil means GetExtension fails with ErrMissingExtension.
	}
	tests := []testcase{
		{pb2.E_NoDefaultDouble, setFloat64, nil},
		{pb2.E_NoDefaultFloat, setFloat32, nil},
		{pb2.E_NoDefaultInt32, setInt32, nil},
		{pb2.E_NoDefaultInt64, setInt64, nil},
		{pb2.E_NoDefaultUint32, setUint32, nil},
		{pb2.E_NoDefaultUint64, setUint64, nil},
		{pb2.E_NoDefaultSint32, setInt32, nil},
		{pb2.E_NoDefaultSint64, setInt64, nil},
		{pb2.E_NoDefaultFixed32, setUint32, nil},
		{pb2.E_NoDefaultFixed64, setUint64, nil},
		{pb2.E_NoDefaultSfixed32, setInt32, nil},
		{pb2.E_NoDefaultSfixed64, setInt64, nil},
		{pb2.E_NoDefaultBool, setBool, nil},
		{pb2.E_NoDefaultBool, setBool2, nil},
		{pb2.E_NoDefaultString, setString, nil},
		{pb2.E_NoDefaultBytes, setBytes, nil},
		{pb2.E_NoDefaultEnum, setEnum, nil},
		{pb2.E_DefaultDouble, setFloat64, float64(3.1415)},
		{pb2.E_DefaultFloat, setFloat32, float32(3.14)},
		{pb2.E_DefaultInt32, setInt32, int32(42)},
		{pb2.E_DefaultInt64, setInt64, int64(43)},
		{pb2.E_DefaultUint32, setUint32, uint32(44)},
		{pb2.E_DefaultUint64, setUint64, uint64(45)},
		{pb2.E_DefaultSint32, setInt32, int32(46)},
		{pb2.E_DefaultSint64, setInt64, int64(47)},
		{pb2.E_DefaultFixed32, setUint32, uint32(48)},
		{pb2.E_DefaultFixed64, setUint64, uint64(49)},
		{pb2.E_DefaultSfixed32, setInt32, int32(50)},
		{pb2.E_DefaultSfixed64, setInt64, int64(51)},
		{pb2.E_DefaultBool, setBool, true},
		{pb2.E_DefaultBool, setBool2, true},
		{pb2.E_DefaultString, setString, "Hello, string,def=foo"},
		{pb2.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
		{pb2.E_DefaultEnum, setEnum, pb2.DefaultsMessage_ONE},
	}

	// checkVal runs a subtest named name. The subtest asserts that
	// test.ext on msg currently holds valWant, compared after dereferencing
	// pointer results. When valWant is nil it instead asserts that
	// GetExtension fails with ErrMissingExtension.
	checkVal := func(t *testing.T, name string, test testcase, msg *pb2.DefaultsMessage, valWant interface{}) {
		t.Run(name, func(t *testing.T) {
			val, err := proto.GetExtension(msg, test.ext)
			if err != nil {
				if valWant != nil {
					t.Errorf("GetExtension(): %s", err)
					return
				}
				if want := proto.ErrMissingExtension; err != want {
					t.Errorf("Unexpected error: got %v, want %v", err, want)
					return
				}
				return
			}

			// All proto2 extension values are either a pointer to a value or a slice of values.
			ty := reflect.TypeOf(val)
			tyWant := reflect.TypeOf(test.ext.ExtensionType)
			if got, want := ty, tyWant; got != want {
				t.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
				return
			}
			tye := ty.Elem()
			tyeWant := tyWant.Elem()
			if got, want := tye, tyeWant; got != want {
				t.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
				return
			}

			// Check the name of the type of the value.
			// If it is an enum it will be type int32 with the name of the enum.
			// Compare against the wanted element type; previously both sides
			// were tye.Name(), a self-comparison that could never fail.
			if got, want := tye.Name(), tyeWant.Name(); got != want {
				t.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
				return
			}

			// Check that value is what we expect.
			// If we have a pointer in val, get the value it points to.
			valExp := val
			if ty.Kind() == reflect.Ptr {
				valExp = reflect.ValueOf(val).Elem().Interface()
			}
			if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
				t.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
				return
			}
		})
	}

	// setTo returns test.want in the form SetExtension is given for test.ext.
	// When ExtensionType is a pointer type, test.want is copied into a newly
	// allocated pointer. Otherwise (e.g. []byte) it is returned unchanged.
	setTo := func(test testcase) interface{} {
		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
			p := reflect.New(typ.Elem())
			p.Elem().Set(reflect.ValueOf(test.want))
			return p.Interface()
		}
		return test.want
	}

	for _, test := range tests {
		msg := &pb2.DefaultsMessage{}
		name := test.ext.Name

		// Check the initial value.
		checkVal(t, name+"/initial", test, msg, test.def)

		// Set the per-type value and check value.
		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
			t.Errorf("%s: SetExtension(): %v", name, err)
			continue
		}
		checkVal(t, name+"/set", test, msg, test.want)

		// Clear the extension and check that the initial value is reported again.
		proto.ClearExtension(msg, test.ext)
		checkVal(t, name+"/cleared", test, msg, test.def)
	}
}
350
// TestNilMessage checks two things. GetExtension must reject a nil
// proto.Message interface with a not-extendable error. The Extension API must
// handle a typed nil *pb2.MyMessage without panicking.
func TestNilMessage(t *testing.T) {
	const label = "nil interface"
	got, err := proto.GetExtension(nil, pb2.E_Ext_More)
	switch {
	case err == nil:
		t.Errorf("%s: got %T %v, expected to fail", label, got, got)
	case !strings.Contains(err.Error(), "extendable"):
		t.Errorf("%s: got error %v, expected not-extendable error", label, err)
	}

	// Regression tests: every function of the Extension API once panicked
	// when handed (*M)(nil) for a concrete message type M. They must now
	// either do nothing or report an error.
	var nilMsg *pb2.MyMessage
	desc := pb2.E_Ext_More

	notExtendable := func(e error) bool {
		return strings.Contains(fmt.Sprint(e), "not an extendable")
	}

	if proto.HasExtension(nilMsg, desc) {
		t.Error("HasExtension(nil) = true")
	}
	if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !notExtendable(err) {
		t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
	}
	if _, err := proto.ExtensionDescs(nilMsg); !notExtendable(err) {
		t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
	}
	if err := proto.SetExtension(nilMsg, desc, nil); !notExtendable(err) {
		t.Errorf("SetExtension(nil) = %q (wrong error)", err)
	}

	// These return nothing; the test only requires that they do not panic.
	proto.ClearExtension(nilMsg, desc)
	proto.ClearAllExtensions(nilMsg)
}
388
// TestExtensionsRoundTrip covers the following operations on E_Ext_More:
//   - HasExtension on an empty message
//   - SetExtension, including overwriting an existing value
//   - GetExtension
//   - ClearExtension
//   - rejection of a value whose type does not match the extension
func TestExtensionsRoundTrip(t *testing.T) {
	msg := &pb2.MyMessage{}
	ext1 := &pb2.Ext{
		Data: proto.String("hi"),
	}
	ext2 := &pb2.Ext{
		Data: proto.String("there"),
	}
	if proto.HasExtension(msg, pb2.E_Ext_More) {
		t.Error("Extension More present unexpectedly")
	}
	if err := proto.SetExtension(msg, pb2.E_Ext_More, ext1); err != nil {
		t.Error(err)
	}
	// A second SetExtension must replace the first value.
	if err := proto.SetExtension(msg, pb2.E_Ext_More, ext2); err != nil {
		t.Error(err)
	}
	e, err := proto.GetExtension(msg, pb2.E_Ext_More)
	if err != nil {
		t.Error(err)
	}
	x, ok := e.(*pb2.Ext)
	if !ok {
		t.Errorf("e has type %T, expected test_proto.Ext", e)
	} else if *x.Data != "there" {
		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
	}

	proto.ClearExtension(msg, pb2.E_Ext_More)
	if _, err = proto.GetExtension(msg, pb2.E_Ext_More); err != proto.ErrMissingExtension {
		// Report the error actually returned, not the stale value e.
		t.Errorf("got %v, expected ErrMissingExtension", err)
	}

	// An int is not a valid value for a message-typed extension.
	if err := proto.SetExtension(msg, pb2.E_Ext_More, 12); err == nil {
		t.Error("expected some sort of type mismatch error, got nil")
	}
}
425
// TestNilExtension checks that SetExtension rejects a typed nil message
// pointer as an extension value. The error must name the offending type.
func TestNilExtension(t *testing.T) {
	msg := &pb2.MyMessage{Count: proto.Int32(1)}
	if err := proto.SetExtension(msg, pb2.E_Ext_Text, proto.String("hello")); err != nil {
		t.Fatal(err)
	}

	want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb2.Ext))
	switch err := proto.SetExtension(msg, pb2.E_Ext_More, (*pb2.Ext)(nil)); {
	case err == nil:
		t.Error("expected SetExtension to fail due to a nil extension")
	case err.Error() != want:
		t.Errorf("expected error %v, got %v", want, err)
	}
	// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
	// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
}
441
// TestMarshalUnmarshalRepeatedExtension round-trips the repeated message
// extension E_RComplex through Marshal/Unmarshal. It checks that the decoded
// slice has the same length as the original and that each element is
// proto.Equal to the original at the same index.
func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
	// Add a repeated extension to the result.
	tests := []struct {
		name string
		ext  []*pb2.ComplexExtension
	}{
		{
			"two fields",
			[]*pb2.ComplexExtension{
				{First: proto.Int32(7)},
				{Second: proto.Int32(11)},
			},
		},
		{
			"repeated field",
			[]*pb2.ComplexExtension{
				{Third: []int32{1000}},
				{Third: []int32{2000}},
			},
		},
		{
			"two fields and repeated field",
			[]*pb2.ComplexExtension{
				{Third: []int32{1000}},
				{First: proto.Int32(9)},
				{Second: proto.Int32(21)},
				{Third: []int32{2000}},
			},
		},
	}
	for _, test := range tests {
		// Marshal message with a repeated extension.
		msg1 := new(pb2.OtherMessage)
		err := proto.SetExtension(msg1, pb2.E_RComplex, test.ext)
		if err != nil {
			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
		}
		b, err := proto.Marshal(msg1)
		if err != nil {
			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
		}

		// Unmarshal and read the merged proto.
		msg2 := new(pb2.OtherMessage)
		err = proto.Unmarshal(b, msg2)
		if err != nil {
			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
		}
		e, err := proto.GetExtension(msg2, pb2.E_RComplex)
		if err != nil {
			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
		}
		// Use comma-ok so an unexpected type fails the test instead of panicking.
		ext, ok := e.([]*pb2.ComplexExtension)
		if !ok {
			t.Fatalf("[%s] Extension has type %T, want []*pb2.ComplexExtension", test.name, e)
		}
		if ext == nil {
			t.Fatalf("[%s] Invalid extension", test.name)
		}
		if len(ext) != len(test.ext) {
			t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
			// Indexing ext[i] below would panic if ext is shorter; skip the
			// element-wise comparison for this case.
			continue
		}
		for i := range test.ext {
			if !proto.Equal(ext[i], test.ext[i]) {
				t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
			}
		}
	}
}
508
// TestUnmarshalRepeatingNonRepeatedExtension concatenates several encodings
// of the singular extension E_Complex and unmarshals the result. It checks
// that the decoded extension equals proto.Merge applied over every input in
// order.
func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
	// We may see multiple instances of the same extension in the wire
	// format. For example, the proto compiler may encode custom options in
	// this way. Here, we verify that we merge the extensions together.
	tests := []struct {
		name string
		ext  []*pb2.ComplexExtension
	}{
		{
			"two fields",
			[]*pb2.ComplexExtension{
				{First: proto.Int32(7)},
				{Second: proto.Int32(11)},
			},
		},
		{
			"repeated field",
			[]*pb2.ComplexExtension{
				{Third: []int32{1000}},
				{Third: []int32{2000}},
			},
		},
		{
			"two fields and repeated field",
			[]*pb2.ComplexExtension{
				{Third: []int32{1000}},
				{First: proto.Int32(9)},
				{Second: proto.Int32(21)},
				{Third: []int32{2000}},
			},
		},
	}
	for _, test := range tests {
		var buf bytes.Buffer
		var want pb2.ComplexExtension

		// Generate a serialized representation of a repeated extension
		// by catenating bytes together.
		for i, e := range test.ext {
			// Merge to create the wanted proto.
			proto.Merge(&want, e)

			// serialize the message
			msg := new(pb2.OtherMessage)
			err := proto.SetExtension(msg, pb2.E_Complex, e)
			if err != nil {
				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
			}
			b, err := proto.Marshal(msg)
			if err != nil {
				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
			}
			buf.Write(b)
		}

		// Unmarshal and read the merged proto.
		msg2 := new(pb2.OtherMessage)
		err := proto.Unmarshal(buf.Bytes(), msg2)
		if err != nil {
			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
		}
		e, err := proto.GetExtension(msg2, pb2.E_Complex)
		if err != nil {
			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
		}
		// Use comma-ok so an unexpected type fails the test instead of panicking.
		ext, ok := e.(*pb2.ComplexExtension)
		if !ok {
			t.Fatalf("[%s] Extension has type %T, want *pb2.ComplexExtension", test.name, e)
		}
		if ext == nil {
			t.Fatalf("[%s] Invalid extension", test.name)
		}
		if !proto.Equal(ext, &want) {
			t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want)
		}
	}
}
583
// TestClearAllExtensions checks that ClearAllExtensions removes an extension
// that was set through an unregistered descriptor.
func TestClearAllExtensions(t *testing.T) {
	// unregistered extension
	desc := &proto.ExtensionDesc{
		ExtendedType:  (*pb2.MyMessage)(nil),
		ExtensionType: (*bool)(nil),
		Field:         101010100,
		Name:          "emptyextension",
		Tag:           "varint,0,opt",
	}
	m := &pb2.MyMessage{}

	// expectHas fails the test if HasExtension(m, desc) differs from want.
	expectHas := func(want bool) {
		t.Helper()
		got := proto.HasExtension(m, desc)
		if got == want {
			return
		}
		if got {
			t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
		} else {
			t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
		}
	}

	expectHas(false)
	if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
		t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
	}
	expectHas(true)
	proto.ClearAllExtensions(m)
	expectHas(false)
}
608
609func TestMarshalRace(t *testing.T) {
610	ext := &pb2.Ext{}
611	m := &pb2.MyMessage{Count: proto.Int32(4)}
612	if err := proto.SetExtension(m, pb2.E_Ext_More, ext); err != nil {
613		t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
614	}
615
616	b, err := proto.Marshal(m)
617	if err != nil {
618		t.Fatalf("Could not marshal message: %v", err)
619	}
620	if err := proto.Unmarshal(b, m); err != nil {
621		t.Fatalf("Could not unmarshal message: %v", err)
622	}
623	// after Unmarshal, the extension is in undecoded form.
624	// GetExtension will decode it lazily. Make sure this does
625	// not race against Marshal.
626
627	wg := sync.WaitGroup{}
628	errs := make(chan error, 3)
629	for n := 3; n > 0; n-- {
630		wg.Add(1)
631		go func() {
632			defer wg.Done()
633			_, err := proto.Marshal(m)
634			errs <- err
635		}()
636	}
637	wg.Wait()
638	close(errs)
639
640	for err = range errs {
641		if err != nil {
642			t.Fatal(err)
643		}
644	}
645}
646