1// Go support for Protocol Buffers - Google's data interchange format
2//
3// Copyright 2014 The Go Authors.  All rights reserved.
4// https://github.com/golang/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//     * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32package proto_test
33
34import (
35	"bytes"
36	"fmt"
37	"io"
38	"reflect"
39	"sort"
40	"strings"
41	"testing"
42
43	"github.com/gogo/protobuf/proto"
44	pb "github.com/gogo/protobuf/proto/test_proto"
45)
46
47func TestGetExtensionsWithMissingExtensions(t *testing.T) {
48	msg := &pb.MyMessage{}
49	ext1 := &pb.Ext{}
50	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
51		t.Fatalf("Could not set ext1: %s", err)
52	}
53	exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
54		pb.E_Ext_More,
55		pb.E_Ext_Text,
56	})
57	if err != nil {
58		t.Fatalf("GetExtensions() failed: %s", err)
59	}
60	if exts[0] != ext1 {
61		t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
62	}
63	if exts[1] != nil {
64		t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
65	}
66}
67
68func TestGetExtensionWithEmptyBuffer(t *testing.T) {
69	// Make sure that GetExtension returns an error if its
70	// undecoded buffer is empty.
71	msg := &pb.MyMessage{}
72	proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
73	_, err := proto.GetExtension(msg, pb.E_Ext_More)
74	if want := io.ErrUnexpectedEOF; err != want {
75		t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
76	}
77}
78
79func TestGetExtensionForIncompleteDesc(t *testing.T) {
80	msg := &pb.MyMessage{Count: proto.Int32(0)}
81	extdesc1 := &proto.ExtensionDesc{
82		ExtendedType:  (*pb.MyMessage)(nil),
83		ExtensionType: (*bool)(nil),
84		Field:         123456789,
85		Name:          "a.b",
86		Tag:           "varint,123456789,opt",
87	}
88	ext1 := proto.Bool(true)
89	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
90		t.Fatalf("Could not set ext1: %s", err)
91	}
92	extdesc2 := &proto.ExtensionDesc{
93		ExtendedType:  (*pb.MyMessage)(nil),
94		ExtensionType: ([]byte)(nil),
95		Field:         123456790,
96		Name:          "a.c",
97		Tag:           "bytes,123456790,opt",
98	}
99	ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
100	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
101		t.Fatalf("Could not set ext2: %s", err)
102	}
103	extdesc3 := &proto.ExtensionDesc{
104		ExtendedType:  (*pb.MyMessage)(nil),
105		ExtensionType: (*pb.Ext)(nil),
106		Field:         123456791,
107		Name:          "a.d",
108		Tag:           "bytes,123456791,opt",
109	}
110	ext3 := &pb.Ext{Data: proto.String("foo")}
111	if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
112		t.Fatalf("Could not set ext3: %s", err)
113	}
114
115	b, err := proto.Marshal(msg)
116	if err != nil {
117		t.Fatalf("Could not marshal msg: %v", err)
118	}
119	if err := proto.Unmarshal(b, msg); err != nil {
120		t.Fatalf("Could not unmarshal into msg: %v", err)
121	}
122
123	var expected proto.Buffer
124	if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
125		t.Fatalf("failed to compute expected prefix for ext1: %s", err)
126	}
127	if err := expected.EncodeVarint(1 /* bool true */); err != nil {
128		t.Fatalf("failed to compute expected value for ext1: %s", err)
129	}
130
131	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
132		t.Fatalf("Failed to get raw value for ext1: %s", err)
133	} else if !reflect.DeepEqual(b, expected.Bytes()) {
134		t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
135	}
136
137	expected = proto.Buffer{} // reset
138	if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
139		t.Fatalf("failed to compute expected prefix for ext2: %s", err)
140	}
141	if err := expected.EncodeRawBytes(ext2); err != nil {
142		t.Fatalf("failed to compute expected value for ext2: %s", err)
143	}
144
145	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
146		t.Fatalf("Failed to get raw value for ext2: %s", err)
147	} else if !reflect.DeepEqual(b, expected.Bytes()) {
148		t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
149	}
150
151	expected = proto.Buffer{} // reset
152	if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
153		t.Fatalf("failed to compute expected prefix for ext3: %s", err)
154	}
155	if b, err := proto.Marshal(ext3); err != nil {
156		t.Fatalf("failed to compute expected value for ext3: %s", err)
157	} else if err := expected.EncodeRawBytes(b); err != nil {
158		t.Fatalf("failed to compute expected value for ext3: %s", err)
159	}
160
161	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
162		t.Fatalf("Failed to get raw value for ext3: %s", err)
163	} else if !reflect.DeepEqual(b, expected.Bytes()) {
164		t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
165	}
166}
167
168func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
169	msg := &pb.MyMessage{Count: proto.Int32(0)}
170	extdesc1 := pb.E_Ext_More
171	if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
172		t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
173	}
174
175	ext1 := &pb.Ext{}
176	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
177		t.Fatalf("Could not set ext1: %s", err)
178	}
179	extdesc2 := &proto.ExtensionDesc{
180		ExtendedType:  (*pb.MyMessage)(nil),
181		ExtensionType: (*bool)(nil),
182		Field:         123456789,
183		Name:          "a.b",
184		Tag:           "varint,123456789,opt",
185	}
186	ext2 := proto.Bool(false)
187	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
188		t.Fatalf("Could not set ext2: %s", err)
189	}
190
191	b, err := proto.Marshal(msg)
192	if err != nil {
193		t.Fatalf("Could not marshal msg: %v", err)
194	}
195	if err = proto.Unmarshal(b, msg); err != nil {
196		t.Fatalf("Could not unmarshal into msg: %v", err)
197	}
198
199	descs, err := proto.ExtensionDescs(msg)
200	if err != nil {
201		t.Fatalf("proto.ExtensionDescs: got error %v", err)
202	}
203	sortExtDescs(descs)
204	wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
205	if !reflect.DeepEqual(descs, wantDescs) {
206		t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
207	}
208}
209
210type ExtensionDescSlice []*proto.ExtensionDesc
211
212func (s ExtensionDescSlice) Len() int           { return len(s) }
213func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
214func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
215
216func sortExtDescs(s []*proto.ExtensionDesc) {
217	sort.Sort(ExtensionDescSlice(s))
218}
219
220func TestGetExtensionStability(t *testing.T) {
221	check := func(m *pb.MyMessage) bool {
222		ext1, err := proto.GetExtension(m, pb.E_Ext_More)
223		if err != nil {
224			t.Fatalf("GetExtension() failed: %s", err)
225		}
226		ext2, err := proto.GetExtension(m, pb.E_Ext_More)
227		if err != nil {
228			t.Fatalf("GetExtension() failed: %s", err)
229		}
230		return ext1 == ext2
231	}
232	msg := &pb.MyMessage{Count: proto.Int32(4)}
233	ext0 := &pb.Ext{}
234	if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
235		t.Fatalf("Could not set ext1: %s", ext0)
236	}
237	if !check(msg) {
238		t.Errorf("GetExtension() not stable before marshaling")
239	}
240	bb, err := proto.Marshal(msg)
241	if err != nil {
242		t.Fatalf("Marshal() failed: %s", err)
243	}
244	msg1 := &pb.MyMessage{}
245	err = proto.Unmarshal(bb, msg1)
246	if err != nil {
247		t.Fatalf("Unmarshal() failed: %s", err)
248	}
249	if !check(msg1) {
250		t.Errorf("GetExtension() not stable after unmarshaling")
251	}
252}
253
254func TestGetExtensionDefaults(t *testing.T) {
255	var setFloat64 float64 = 1
256	var setFloat32 float32 = 2
257	var setInt32 int32 = 3
258	var setInt64 int64 = 4
259	var setUint32 uint32 = 5
260	var setUint64 uint64 = 6
261	var setBool = true
262	var setBool2 = false
263	var setString = "Goodnight string"
264	var setBytes = []byte("Goodnight bytes")
265	var setEnum = pb.DefaultsMessage_TWO
266
267	type testcase struct {
268		ext  *proto.ExtensionDesc // Extension we are testing.
269		want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
270		def  interface{}          // Expected value of extension after ClearExtension().
271	}
272	tests := []testcase{
273		{pb.E_NoDefaultDouble, setFloat64, nil},
274		{pb.E_NoDefaultFloat, setFloat32, nil},
275		{pb.E_NoDefaultInt32, setInt32, nil},
276		{pb.E_NoDefaultInt64, setInt64, nil},
277		{pb.E_NoDefaultUint32, setUint32, nil},
278		{pb.E_NoDefaultUint64, setUint64, nil},
279		{pb.E_NoDefaultSint32, setInt32, nil},
280		{pb.E_NoDefaultSint64, setInt64, nil},
281		{pb.E_NoDefaultFixed32, setUint32, nil},
282		{pb.E_NoDefaultFixed64, setUint64, nil},
283		{pb.E_NoDefaultSfixed32, setInt32, nil},
284		{pb.E_NoDefaultSfixed64, setInt64, nil},
285		{pb.E_NoDefaultBool, setBool, nil},
286		{pb.E_NoDefaultBool, setBool2, nil},
287		{pb.E_NoDefaultString, setString, nil},
288		{pb.E_NoDefaultBytes, setBytes, nil},
289		{pb.E_NoDefaultEnum, setEnum, nil},
290		{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
291		{pb.E_DefaultFloat, setFloat32, float32(3.14)},
292		{pb.E_DefaultInt32, setInt32, int32(42)},
293		{pb.E_DefaultInt64, setInt64, int64(43)},
294		{pb.E_DefaultUint32, setUint32, uint32(44)},
295		{pb.E_DefaultUint64, setUint64, uint64(45)},
296		{pb.E_DefaultSint32, setInt32, int32(46)},
297		{pb.E_DefaultSint64, setInt64, int64(47)},
298		{pb.E_DefaultFixed32, setUint32, uint32(48)},
299		{pb.E_DefaultFixed64, setUint64, uint64(49)},
300		{pb.E_DefaultSfixed32, setInt32, int32(50)},
301		{pb.E_DefaultSfixed64, setInt64, int64(51)},
302		{pb.E_DefaultBool, setBool, true},
303		{pb.E_DefaultBool, setBool2, true},
304		{pb.E_DefaultString, setString, "Hello, string,def=foo"},
305		{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
306		{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
307	}
308
309	checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
310		val, err := proto.GetExtension(msg, test.ext)
311		if err != nil {
312			if valWant != nil {
313				return fmt.Errorf("GetExtension(): %s", err)
314			}
315			if want := proto.ErrMissingExtension; err != want {
316				return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
317			}
318			return nil
319		}
320
321		// All proto2 extension values are either a pointer to a value or a slice of values.
322		ty := reflect.TypeOf(val)
323		tyWant := reflect.TypeOf(test.ext.ExtensionType)
324		if got, want := ty, tyWant; got != want {
325			return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
326		}
327		tye := ty.Elem()
328		tyeWant := tyWant.Elem()
329		if got, want := tye, tyeWant; got != want {
330			return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
331		}
332
333		// Check the name of the type of the value.
334		// If it is an enum it will be type int32 with the name of the enum.
335		if got, want := tye.Name(), tye.Name(); got != want {
336			return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
337		}
338
339		// Check that value is what we expect.
340		// If we have a pointer in val, get the value it points to.
341		valExp := val
342		if ty.Kind() == reflect.Ptr {
343			valExp = reflect.ValueOf(val).Elem().Interface()
344		}
345		if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
346			return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
347		}
348
349		return nil
350	}
351
352	setTo := func(test testcase) interface{} {
353		setTo := reflect.ValueOf(test.want)
354		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
355			setTo = reflect.New(typ).Elem()
356			setTo.Set(reflect.New(setTo.Type().Elem()))
357			setTo.Elem().Set(reflect.ValueOf(test.want))
358		}
359		return setTo.Interface()
360	}
361
362	for _, test := range tests {
363		msg := &pb.DefaultsMessage{}
364		name := test.ext.Name
365
366		// Check the initial value.
367		if err := checkVal(test, msg, test.def); err != nil {
368			t.Errorf("%s: %v", name, err)
369		}
370
371		// Set the per-type value and check value.
372		name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
373		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
374			t.Errorf("%s: SetExtension(): %v", name, err)
375			continue
376		}
377		if err := checkVal(test, msg, test.want); err != nil {
378			t.Errorf("%s: %v", name, err)
379			continue
380		}
381
382		// Set and check the value.
383		name += " (cleared)"
384		proto.ClearExtension(msg, test.ext)
385		if err := checkVal(test, msg, test.def); err != nil {
386			t.Errorf("%s: %v", name, err)
387		}
388	}
389}
390
391func TestNilMessage(t *testing.T) {
392	name := "nil interface"
393	if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil {
394		t.Errorf("%s: got %T %v, expected to fail", name, got, got)
395	} else if !strings.Contains(err.Error(), "extendable") {
396		t.Errorf("%s: got error %v, expected not-extendable error", name, err)
397	}
398
399	// Regression tests: all functions of the Extension API
400	// used to panic when passed (*M)(nil), where M is a concrete message
401	// type.  Now they handle this gracefully as a no-op or reported error.
402	var nilMsg *pb.MyMessage
403	desc := pb.E_Ext_More
404
405	isNotExtendable := func(err error) bool {
406		return strings.Contains(fmt.Sprint(err), "not extendable")
407	}
408
409	if proto.HasExtension(nilMsg, desc) {
410		t.Error("HasExtension(nil) = true")
411	}
412
413	if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
414		t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
415	}
416
417	if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
418		t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
419	}
420
421	if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
422		t.Errorf("SetExtension(nil) = %q (wrong error)", err)
423	}
424
425	proto.ClearExtension(nilMsg, desc) // no-op
426	proto.ClearAllExtensions(nilMsg)   // no-op
427}
428
429func TestExtensionsRoundTrip(t *testing.T) {
430	msg := &pb.MyMessage{}
431	ext1 := &pb.Ext{
432		Data: proto.String("hi"),
433	}
434	ext2 := &pb.Ext{
435		Data: proto.String("there"),
436	}
437	exists := proto.HasExtension(msg, pb.E_Ext_More)
438	if exists {
439		t.Error("Extension More present unexpectedly")
440	}
441	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
442		t.Error(err)
443	}
444	if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
445		t.Error(err)
446	}
447	e, err := proto.GetExtension(msg, pb.E_Ext_More)
448	if err != nil {
449		t.Error(err)
450	}
451	x, ok := e.(*pb.Ext)
452	if !ok {
453		t.Errorf("e has type %T, expected test_proto.Ext", e)
454	} else if *x.Data != "there" {
455		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
456	}
457	proto.ClearExtension(msg, pb.E_Ext_More)
458	if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
459		t.Errorf("got %v, expected ErrMissingExtension", e)
460	}
461	if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
462		t.Error("expected bad extension error, got nil")
463	}
464	if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
465		t.Error("expected extension err")
466	}
467	if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
468		t.Error("expected some sort of type mismatch error, got nil")
469	}
470}
471
472func TestNilExtension(t *testing.T) {
473	msg := &pb.MyMessage{
474		Count: proto.Int32(1),
475	}
476	if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
477		t.Fatal(err)
478	}
479	if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
480		t.Error("expected SetExtension to fail due to a nil extension")
481	} else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want {
482		t.Errorf("expected error %v, got %v", want, err)
483	}
484	// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
485	// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
486}
487
488func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
489	// Add a repeated extension to the result.
490	tests := []struct {
491		name string
492		ext  []*pb.ComplexExtension
493	}{
494		{
495			"two fields",
496			[]*pb.ComplexExtension{
497				{First: proto.Int32(7)},
498				{Second: proto.Int32(11)},
499			},
500		},
501		{
502			"repeated field",
503			[]*pb.ComplexExtension{
504				{Third: []int32{1000}},
505				{Third: []int32{2000}},
506			},
507		},
508		{
509			"two fields and repeated field",
510			[]*pb.ComplexExtension{
511				{Third: []int32{1000}},
512				{First: proto.Int32(9)},
513				{Second: proto.Int32(21)},
514				{Third: []int32{2000}},
515			},
516		},
517	}
518	for _, test := range tests {
519		// Marshal message with a repeated extension.
520		msg1 := new(pb.OtherMessage)
521		err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
522		if err != nil {
523			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
524		}
525		b, err := proto.Marshal(msg1)
526		if err != nil {
527			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
528		}
529
530		// Unmarshal and read the merged proto.
531		msg2 := new(pb.OtherMessage)
532		err = proto.Unmarshal(b, msg2)
533		if err != nil {
534			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
535		}
536		e, err := proto.GetExtension(msg2, pb.E_RComplex)
537		if err != nil {
538			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
539		}
540		ext := e.([]*pb.ComplexExtension)
541		if ext == nil {
542			t.Fatalf("[%s] Invalid extension", test.name)
543		}
544		if len(ext) != len(test.ext) {
545			t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
546		}
547		for i := range test.ext {
548			if !proto.Equal(ext[i], test.ext[i]) {
549				t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
550			}
551		}
552	}
553}
554
555func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
556	// We may see multiple instances of the same extension in the wire
557	// format. For example, the proto compiler may encode custom options in
558	// this way. Here, we verify that we merge the extensions together.
559	tests := []struct {
560		name string
561		ext  []*pb.ComplexExtension
562	}{
563		{
564			"two fields",
565			[]*pb.ComplexExtension{
566				{First: proto.Int32(7)},
567				{Second: proto.Int32(11)},
568			},
569		},
570		{
571			"repeated field",
572			[]*pb.ComplexExtension{
573				{Third: []int32{1000}},
574				{Third: []int32{2000}},
575			},
576		},
577		{
578			"two fields and repeated field",
579			[]*pb.ComplexExtension{
580				{Third: []int32{1000}},
581				{First: proto.Int32(9)},
582				{Second: proto.Int32(21)},
583				{Third: []int32{2000}},
584			},
585		},
586	}
587	for _, test := range tests {
588		var buf bytes.Buffer
589		var want pb.ComplexExtension
590
591		// Generate a serialized representation of a repeated extension
592		// by catenating bytes together.
593		for i, e := range test.ext {
594			// Merge to create the wanted proto.
595			proto.Merge(&want, e)
596
597			// serialize the message
598			msg := new(pb.OtherMessage)
599			err := proto.SetExtension(msg, pb.E_Complex, e)
600			if err != nil {
601				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
602			}
603			b, err := proto.Marshal(msg)
604			if err != nil {
605				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
606			}
607			buf.Write(b)
608		}
609
610		// Unmarshal and read the merged proto.
611		msg2 := new(pb.OtherMessage)
612		err := proto.Unmarshal(buf.Bytes(), msg2)
613		if err != nil {
614			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
615		}
616		e, err := proto.GetExtension(msg2, pb.E_Complex)
617		if err != nil {
618			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
619		}
620		ext := e.(*pb.ComplexExtension)
621		if ext == nil {
622			t.Fatalf("[%s] Invalid extension", test.name)
623		}
624		if !proto.Equal(ext, &want) {
625			t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, &want)
626
627		}
628	}
629}
630
631func TestClearAllExtensions(t *testing.T) {
632	// unregistered extension
633	desc := &proto.ExtensionDesc{
634		ExtendedType:  (*pb.MyMessage)(nil),
635		ExtensionType: (*bool)(nil),
636		Field:         101010100,
637		Name:          "emptyextension",
638		Tag:           "varint,0,opt",
639	}
640	m := &pb.MyMessage{}
641	if proto.HasExtension(m, desc) {
642		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
643	}
644	if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
645		t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
646	}
647	if !proto.HasExtension(m, desc) {
648		t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
649	}
650	proto.ClearAllExtensions(m)
651	if proto.HasExtension(m, desc) {
652		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
653	}
654}
655
656func TestMarshalRace(t *testing.T) {
657	ext := &pb.Ext{}
658	m := &pb.MyMessage{Count: proto.Int32(4)}
659	if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil {
660		t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
661	}
662
663	b, err := proto.Marshal(m)
664	if err != nil {
665		t.Fatalf("Could not marshal message: %v", err)
666	}
667	if err := proto.Unmarshal(b, m); err != nil {
668		t.Fatalf("Could not unmarshal message: %v", err)
669	}
670	// after Unmarshal, the extension is in undecoded form.
671	// GetExtension will decode it lazily. Make sure this does
672	// not race against Marshal.
673
674	errChan := make(chan error, 6)
675	for n := 3; n > 0; n-- {
676		go func() {
677			_, err := proto.Marshal(m)
678			errChan <- err
679		}()
680		go func() {
681			_, err := proto.GetExtension(m, pb.E_Ext_More)
682			errChan <- err
683		}()
684	}
685	for i := 0; i < 6; i++ {
686		err := <-errChan
687		if err != nil {
688			t.Fatal(err)
689		}
690	}
691}
692