1// Go support for Protocol Buffers - Google's data interchange format
2//
3// Copyright 2014 The Go Authors.  All rights reserved.
4// https://github.com/golang/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//     * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32package proto_test
33
34import (
35	"bytes"
36	"fmt"
37	"io"
38	"reflect"
39	"sort"
40	"strings"
41	"sync"
42	"testing"
43
44	"github.com/golang/protobuf/proto"
45	pb "github.com/golang/protobuf/proto/test_proto"
46)
47
48func TestGetExtensionsWithMissingExtensions(t *testing.T) {
49	msg := &pb.MyMessage{}
50	ext1 := &pb.Ext{}
51	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
52		t.Fatalf("Could not set ext1: %s", err)
53	}
54	exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
55		pb.E_Ext_More,
56		pb.E_Ext_Text,
57	})
58	if err != nil {
59		t.Fatalf("GetExtensions() failed: %s", err)
60	}
61	if exts[0] != ext1 {
62		t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
63	}
64	if exts[1] != nil {
65		t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
66	}
67}
68
69func TestGetExtensionWithEmptyBuffer(t *testing.T) {
70	// Make sure that GetExtension returns an error if its
71	// undecoded buffer is empty.
72	msg := &pb.MyMessage{}
73	proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
74	_, err := proto.GetExtension(msg, pb.E_Ext_More)
75	if want := io.ErrUnexpectedEOF; err != want {
76		t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
77	}
78}
79
80func TestGetExtensionForIncompleteDesc(t *testing.T) {
81	msg := &pb.MyMessage{Count: proto.Int32(0)}
82	extdesc1 := &proto.ExtensionDesc{
83		ExtendedType:  (*pb.MyMessage)(nil),
84		ExtensionType: (*bool)(nil),
85		Field:         123456789,
86		Name:          "a.b",
87		Tag:           "varint,123456789,opt",
88	}
89	ext1 := proto.Bool(true)
90	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
91		t.Fatalf("Could not set ext1: %s", err)
92	}
93	extdesc2 := &proto.ExtensionDesc{
94		ExtendedType:  (*pb.MyMessage)(nil),
95		ExtensionType: ([]byte)(nil),
96		Field:         123456790,
97		Name:          "a.c",
98		Tag:           "bytes,123456790,opt",
99	}
100	ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
101	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
102		t.Fatalf("Could not set ext2: %s", err)
103	}
104	extdesc3 := &proto.ExtensionDesc{
105		ExtendedType:  (*pb.MyMessage)(nil),
106		ExtensionType: (*pb.Ext)(nil),
107		Field:         123456791,
108		Name:          "a.d",
109		Tag:           "bytes,123456791,opt",
110	}
111	ext3 := &pb.Ext{Data: proto.String("foo")}
112	if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
113		t.Fatalf("Could not set ext3: %s", err)
114	}
115
116	b, err := proto.Marshal(msg)
117	if err != nil {
118		t.Fatalf("Could not marshal msg: %v", err)
119	}
120	if err := proto.Unmarshal(b, msg); err != nil {
121		t.Fatalf("Could not unmarshal into msg: %v", err)
122	}
123
124	var expected proto.Buffer
125	if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
126		t.Fatalf("failed to compute expected prefix for ext1: %s", err)
127	}
128	if err := expected.EncodeVarint(1 /* bool true */); err != nil {
129		t.Fatalf("failed to compute expected value for ext1: %s", err)
130	}
131
132	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
133		t.Fatalf("Failed to get raw value for ext1: %s", err)
134	} else if !reflect.DeepEqual(b, expected.Bytes()) {
135		t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
136	}
137
138	expected = proto.Buffer{} // reset
139	if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
140		t.Fatalf("failed to compute expected prefix for ext2: %s", err)
141	}
142	if err := expected.EncodeRawBytes(ext2); err != nil {
143		t.Fatalf("failed to compute expected value for ext2: %s", err)
144	}
145
146	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
147		t.Fatalf("Failed to get raw value for ext2: %s", err)
148	} else if !reflect.DeepEqual(b, expected.Bytes()) {
149		t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
150	}
151
152	expected = proto.Buffer{} // reset
153	if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
154		t.Fatalf("failed to compute expected prefix for ext3: %s", err)
155	}
156	if b, err := proto.Marshal(ext3); err != nil {
157		t.Fatalf("failed to compute expected value for ext3: %s", err)
158	} else if err := expected.EncodeRawBytes(b); err != nil {
159		t.Fatalf("failed to compute expected value for ext3: %s", err)
160	}
161
162	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
163		t.Fatalf("Failed to get raw value for ext3: %s", err)
164	} else if !reflect.DeepEqual(b, expected.Bytes()) {
165		t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
166	}
167}
168
169func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
170	msg := &pb.MyMessage{Count: proto.Int32(0)}
171	extdesc1 := pb.E_Ext_More
172	if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
173		t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
174	}
175
176	ext1 := &pb.Ext{}
177	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
178		t.Fatalf("Could not set ext1: %s", err)
179	}
180	extdesc2 := &proto.ExtensionDesc{
181		ExtendedType:  (*pb.MyMessage)(nil),
182		ExtensionType: (*bool)(nil),
183		Field:         123456789,
184		Name:          "a.b",
185		Tag:           "varint,123456789,opt",
186	}
187	ext2 := proto.Bool(false)
188	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
189		t.Fatalf("Could not set ext2: %s", err)
190	}
191
192	b, err := proto.Marshal(msg)
193	if err != nil {
194		t.Fatalf("Could not marshal msg: %v", err)
195	}
196	if err := proto.Unmarshal(b, msg); err != nil {
197		t.Fatalf("Could not unmarshal into msg: %v", err)
198	}
199
200	descs, err := proto.ExtensionDescs(msg)
201	if err != nil {
202		t.Fatalf("proto.ExtensionDescs: got error %v", err)
203	}
204	sortExtDescs(descs)
205	wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
206	if !reflect.DeepEqual(descs, wantDescs) {
207		t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
208	}
209}
210
211type ExtensionDescSlice []*proto.ExtensionDesc
212
213func (s ExtensionDescSlice) Len() int           { return len(s) }
214func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
215func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
216
217func sortExtDescs(s []*proto.ExtensionDesc) {
218	sort.Sort(ExtensionDescSlice(s))
219}
220
221func TestGetExtensionStability(t *testing.T) {
222	check := func(m *pb.MyMessage) bool {
223		ext1, err := proto.GetExtension(m, pb.E_Ext_More)
224		if err != nil {
225			t.Fatalf("GetExtension() failed: %s", err)
226		}
227		ext2, err := proto.GetExtension(m, pb.E_Ext_More)
228		if err != nil {
229			t.Fatalf("GetExtension() failed: %s", err)
230		}
231		return ext1 == ext2
232	}
233	msg := &pb.MyMessage{Count: proto.Int32(4)}
234	ext0 := &pb.Ext{}
235	if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
236		t.Fatalf("Could not set ext1: %s", ext0)
237	}
238	if !check(msg) {
239		t.Errorf("GetExtension() not stable before marshaling")
240	}
241	bb, err := proto.Marshal(msg)
242	if err != nil {
243		t.Fatalf("Marshal() failed: %s", err)
244	}
245	msg1 := &pb.MyMessage{}
246	err = proto.Unmarshal(bb, msg1)
247	if err != nil {
248		t.Fatalf("Unmarshal() failed: %s", err)
249	}
250	if !check(msg1) {
251		t.Errorf("GetExtension() not stable after unmarshaling")
252	}
253}
254
255func TestGetExtensionDefaults(t *testing.T) {
256	var setFloat64 float64 = 1
257	var setFloat32 float32 = 2
258	var setInt32 int32 = 3
259	var setInt64 int64 = 4
260	var setUint32 uint32 = 5
261	var setUint64 uint64 = 6
262	var setBool = true
263	var setBool2 = false
264	var setString = "Goodnight string"
265	var setBytes = []byte("Goodnight bytes")
266	var setEnum = pb.DefaultsMessage_TWO
267
268	type testcase struct {
269		ext  *proto.ExtensionDesc // Extension we are testing.
270		want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
271		def  interface{}          // Expected value of extension after ClearExtension().
272	}
273	tests := []testcase{
274		{pb.E_NoDefaultDouble, setFloat64, nil},
275		{pb.E_NoDefaultFloat, setFloat32, nil},
276		{pb.E_NoDefaultInt32, setInt32, nil},
277		{pb.E_NoDefaultInt64, setInt64, nil},
278		{pb.E_NoDefaultUint32, setUint32, nil},
279		{pb.E_NoDefaultUint64, setUint64, nil},
280		{pb.E_NoDefaultSint32, setInt32, nil},
281		{pb.E_NoDefaultSint64, setInt64, nil},
282		{pb.E_NoDefaultFixed32, setUint32, nil},
283		{pb.E_NoDefaultFixed64, setUint64, nil},
284		{pb.E_NoDefaultSfixed32, setInt32, nil},
285		{pb.E_NoDefaultSfixed64, setInt64, nil},
286		{pb.E_NoDefaultBool, setBool, nil},
287		{pb.E_NoDefaultBool, setBool2, nil},
288		{pb.E_NoDefaultString, setString, nil},
289		{pb.E_NoDefaultBytes, setBytes, nil},
290		{pb.E_NoDefaultEnum, setEnum, nil},
291		{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
292		{pb.E_DefaultFloat, setFloat32, float32(3.14)},
293		{pb.E_DefaultInt32, setInt32, int32(42)},
294		{pb.E_DefaultInt64, setInt64, int64(43)},
295		{pb.E_DefaultUint32, setUint32, uint32(44)},
296		{pb.E_DefaultUint64, setUint64, uint64(45)},
297		{pb.E_DefaultSint32, setInt32, int32(46)},
298		{pb.E_DefaultSint64, setInt64, int64(47)},
299		{pb.E_DefaultFixed32, setUint32, uint32(48)},
300		{pb.E_DefaultFixed64, setUint64, uint64(49)},
301		{pb.E_DefaultSfixed32, setInt32, int32(50)},
302		{pb.E_DefaultSfixed64, setInt64, int64(51)},
303		{pb.E_DefaultBool, setBool, true},
304		{pb.E_DefaultBool, setBool2, true},
305		{pb.E_DefaultString, setString, "Hello, string,def=foo"},
306		{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
307		{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
308	}
309
310	checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
311		val, err := proto.GetExtension(msg, test.ext)
312		if err != nil {
313			if valWant != nil {
314				return fmt.Errorf("GetExtension(): %s", err)
315			}
316			if want := proto.ErrMissingExtension; err != want {
317				return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
318			}
319			return nil
320		}
321
322		// All proto2 extension values are either a pointer to a value or a slice of values.
323		ty := reflect.TypeOf(val)
324		tyWant := reflect.TypeOf(test.ext.ExtensionType)
325		if got, want := ty, tyWant; got != want {
326			return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
327		}
328		tye := ty.Elem()
329		tyeWant := tyWant.Elem()
330		if got, want := tye, tyeWant; got != want {
331			return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
332		}
333
334		// Check the name of the type of the value.
335		// If it is an enum it will be type int32 with the name of the enum.
336		if got, want := tye.Name(), tye.Name(); got != want {
337			return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
338		}
339
340		// Check that value is what we expect.
341		// If we have a pointer in val, get the value it points to.
342		valExp := val
343		if ty.Kind() == reflect.Ptr {
344			valExp = reflect.ValueOf(val).Elem().Interface()
345		}
346		if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
347			return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
348		}
349
350		return nil
351	}
352
353	setTo := func(test testcase) interface{} {
354		setTo := reflect.ValueOf(test.want)
355		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
356			setTo = reflect.New(typ).Elem()
357			setTo.Set(reflect.New(setTo.Type().Elem()))
358			setTo.Elem().Set(reflect.ValueOf(test.want))
359		}
360		return setTo.Interface()
361	}
362
363	for _, test := range tests {
364		msg := &pb.DefaultsMessage{}
365		name := test.ext.Name
366
367		// Check the initial value.
368		if err := checkVal(test, msg, test.def); err != nil {
369			t.Errorf("%s: %v", name, err)
370		}
371
372		// Set the per-type value and check value.
373		name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
374		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
375			t.Errorf("%s: SetExtension(): %v", name, err)
376			continue
377		}
378		if err := checkVal(test, msg, test.want); err != nil {
379			t.Errorf("%s: %v", name, err)
380			continue
381		}
382
383		// Set and check the value.
384		name += " (cleared)"
385		proto.ClearExtension(msg, test.ext)
386		if err := checkVal(test, msg, test.def); err != nil {
387			t.Errorf("%s: %v", name, err)
388		}
389	}
390}
391
392func TestNilMessage(t *testing.T) {
393	name := "nil interface"
394	if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil {
395		t.Errorf("%s: got %T %v, expected to fail", name, got, got)
396	} else if !strings.Contains(err.Error(), "extendable") {
397		t.Errorf("%s: got error %v, expected not-extendable error", name, err)
398	}
399
400	// Regression tests: all functions of the Extension API
401	// used to panic when passed (*M)(nil), where M is a concrete message
402	// type.  Now they handle this gracefully as a no-op or reported error.
403	var nilMsg *pb.MyMessage
404	desc := pb.E_Ext_More
405
406	isNotExtendable := func(err error) bool {
407		return strings.Contains(fmt.Sprint(err), "not extendable")
408	}
409
410	if proto.HasExtension(nilMsg, desc) {
411		t.Error("HasExtension(nil) = true")
412	}
413
414	if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
415		t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
416	}
417
418	if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
419		t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
420	}
421
422	if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
423		t.Errorf("SetExtension(nil) = %q (wrong error)", err)
424	}
425
426	proto.ClearExtension(nilMsg, desc) // no-op
427	proto.ClearAllExtensions(nilMsg)   // no-op
428}
429
430func TestExtensionsRoundTrip(t *testing.T) {
431	msg := &pb.MyMessage{}
432	ext1 := &pb.Ext{
433		Data: proto.String("hi"),
434	}
435	ext2 := &pb.Ext{
436		Data: proto.String("there"),
437	}
438	exists := proto.HasExtension(msg, pb.E_Ext_More)
439	if exists {
440		t.Error("Extension More present unexpectedly")
441	}
442	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
443		t.Error(err)
444	}
445	if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
446		t.Error(err)
447	}
448	e, err := proto.GetExtension(msg, pb.E_Ext_More)
449	if err != nil {
450		t.Error(err)
451	}
452	x, ok := e.(*pb.Ext)
453	if !ok {
454		t.Errorf("e has type %T, expected test_proto.Ext", e)
455	} else if *x.Data != "there" {
456		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
457	}
458	proto.ClearExtension(msg, pb.E_Ext_More)
459	if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
460		t.Errorf("got %v, expected ErrMissingExtension", e)
461	}
462	if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
463		t.Error("expected bad extension error, got nil")
464	}
465	if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
466		t.Error("expected extension err")
467	}
468	if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
469		t.Error("expected some sort of type mismatch error, got nil")
470	}
471}
472
473func TestNilExtension(t *testing.T) {
474	msg := &pb.MyMessage{
475		Count: proto.Int32(1),
476	}
477	if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
478		t.Fatal(err)
479	}
480	if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
481		t.Error("expected SetExtension to fail due to a nil extension")
482	} else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want {
483		t.Errorf("expected error %v, got %v", want, err)
484	}
485	// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
486	// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
487}
488
489func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
490	// Add a repeated extension to the result.
491	tests := []struct {
492		name string
493		ext  []*pb.ComplexExtension
494	}{
495		{
496			"two fields",
497			[]*pb.ComplexExtension{
498				{First: proto.Int32(7)},
499				{Second: proto.Int32(11)},
500			},
501		},
502		{
503			"repeated field",
504			[]*pb.ComplexExtension{
505				{Third: []int32{1000}},
506				{Third: []int32{2000}},
507			},
508		},
509		{
510			"two fields and repeated field",
511			[]*pb.ComplexExtension{
512				{Third: []int32{1000}},
513				{First: proto.Int32(9)},
514				{Second: proto.Int32(21)},
515				{Third: []int32{2000}},
516			},
517		},
518	}
519	for _, test := range tests {
520		// Marshal message with a repeated extension.
521		msg1 := new(pb.OtherMessage)
522		err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
523		if err != nil {
524			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
525		}
526		b, err := proto.Marshal(msg1)
527		if err != nil {
528			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
529		}
530
531		// Unmarshal and read the merged proto.
532		msg2 := new(pb.OtherMessage)
533		err = proto.Unmarshal(b, msg2)
534		if err != nil {
535			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
536		}
537		e, err := proto.GetExtension(msg2, pb.E_RComplex)
538		if err != nil {
539			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
540		}
541		ext := e.([]*pb.ComplexExtension)
542		if ext == nil {
543			t.Fatalf("[%s] Invalid extension", test.name)
544		}
545		if len(ext) != len(test.ext) {
546			t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
547		}
548		for i := range test.ext {
549			if !proto.Equal(ext[i], test.ext[i]) {
550				t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
551			}
552		}
553	}
554}
555
556func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
557	// We may see multiple instances of the same extension in the wire
558	// format. For example, the proto compiler may encode custom options in
559	// this way. Here, we verify that we merge the extensions together.
560	tests := []struct {
561		name string
562		ext  []*pb.ComplexExtension
563	}{
564		{
565			"two fields",
566			[]*pb.ComplexExtension{
567				{First: proto.Int32(7)},
568				{Second: proto.Int32(11)},
569			},
570		},
571		{
572			"repeated field",
573			[]*pb.ComplexExtension{
574				{Third: []int32{1000}},
575				{Third: []int32{2000}},
576			},
577		},
578		{
579			"two fields and repeated field",
580			[]*pb.ComplexExtension{
581				{Third: []int32{1000}},
582				{First: proto.Int32(9)},
583				{Second: proto.Int32(21)},
584				{Third: []int32{2000}},
585			},
586		},
587	}
588	for _, test := range tests {
589		var buf bytes.Buffer
590		var want pb.ComplexExtension
591
592		// Generate a serialized representation of a repeated extension
593		// by catenating bytes together.
594		for i, e := range test.ext {
595			// Merge to create the wanted proto.
596			proto.Merge(&want, e)
597
598			// serialize the message
599			msg := new(pb.OtherMessage)
600			err := proto.SetExtension(msg, pb.E_Complex, e)
601			if err != nil {
602				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
603			}
604			b, err := proto.Marshal(msg)
605			if err != nil {
606				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
607			}
608			buf.Write(b)
609		}
610
611		// Unmarshal and read the merged proto.
612		msg2 := new(pb.OtherMessage)
613		err := proto.Unmarshal(buf.Bytes(), msg2)
614		if err != nil {
615			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
616		}
617		e, err := proto.GetExtension(msg2, pb.E_Complex)
618		if err != nil {
619			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
620		}
621		ext := e.(*pb.ComplexExtension)
622		if ext == nil {
623			t.Fatalf("[%s] Invalid extension", test.name)
624		}
625		if !proto.Equal(ext, &want) {
626			t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want)
627		}
628	}
629}
630
631func TestClearAllExtensions(t *testing.T) {
632	// unregistered extension
633	desc := &proto.ExtensionDesc{
634		ExtendedType:  (*pb.MyMessage)(nil),
635		ExtensionType: (*bool)(nil),
636		Field:         101010100,
637		Name:          "emptyextension",
638		Tag:           "varint,0,opt",
639	}
640	m := &pb.MyMessage{}
641	if proto.HasExtension(m, desc) {
642		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
643	}
644	if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
645		t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
646	}
647	if !proto.HasExtension(m, desc) {
648		t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
649	}
650	proto.ClearAllExtensions(m)
651	if proto.HasExtension(m, desc) {
652		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
653	}
654}
655
656func TestMarshalRace(t *testing.T) {
657	ext := &pb.Ext{}
658	m := &pb.MyMessage{Count: proto.Int32(4)}
659	if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil {
660		t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
661	}
662
663	b, err := proto.Marshal(m)
664	if err != nil {
665		t.Fatalf("Could not marshal message: %v", err)
666	}
667	if err := proto.Unmarshal(b, m); err != nil {
668		t.Fatalf("Could not unmarshal message: %v", err)
669	}
670	// after Unmarshal, the extension is in undecoded form.
671	// GetExtension will decode it lazily. Make sure this does
672	// not race against Marshal.
673
674	wg := sync.WaitGroup{}
675	errs := make(chan error, 3)
676	for n := 3; n > 0; n-- {
677		wg.Add(1)
678		go func() {
679			defer wg.Done()
680			_, err := proto.Marshal(m)
681			errs <- err
682		}()
683	}
684	wg.Wait()
685	close(errs)
686
687	for err = range errs {
688		if err != nil {
689			t.Fatal(err)
690		}
691	}
692}
693