1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package prototest exercises protobuf reflection.
6package prototest
7
8import (
9	"bytes"
10	"fmt"
11	"math"
12	"reflect"
13	"sort"
14	"strings"
15	"testing"
16
17	"google.golang.org/protobuf/encoding/prototext"
18	"google.golang.org/protobuf/encoding/protowire"
19	"google.golang.org/protobuf/proto"
20	"google.golang.org/protobuf/reflect/protoreflect"
21	pref "google.golang.org/protobuf/reflect/protoreflect"
22	"google.golang.org/protobuf/reflect/protoregistry"
23)
24
25// TODO: Test invalid field descriptors or oneof descriptors.
26// TODO: This should test the functionality that can be provided by fast-paths.
27
28// Message tests a message implemention.
29type Message struct {
30	// Resolver is used to determine the list of extension fields to test with.
31	// If nil, this defaults to using protoregistry.GlobalTypes.
32	Resolver interface {
33		FindExtensionByName(field pref.FullName) (pref.ExtensionType, error)
34		FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error)
35		RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool)
36	}
37}
38
39// Test performs tests on a MessageType implementation.
40func (test Message) Test(t testing.TB, mt pref.MessageType) {
41	testType(t, mt)
42
43	md := mt.Descriptor()
44	m1 := mt.New()
45	for i := 0; i < md.Fields().Len(); i++ {
46		fd := md.Fields().Get(i)
47		testField(t, m1, fd)
48	}
49	if test.Resolver == nil {
50		test.Resolver = protoregistry.GlobalTypes
51	}
52	var extTypes []pref.ExtensionType
53	test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
54		extTypes = append(extTypes, e)
55		return true
56	})
57	for _, xt := range extTypes {
58		testField(t, m1, xt.TypeDescriptor())
59	}
60	for i := 0; i < md.Oneofs().Len(); i++ {
61		testOneof(t, m1, md.Oneofs().Get(i))
62	}
63	testUnknown(t, m1)
64
65	// Test round-trip marshal/unmarshal.
66	m2 := mt.New().Interface()
67	populateMessage(m2.ProtoReflect(), 1, nil)
68	for _, xt := range extTypes {
69		m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
70	}
71	b, err := proto.MarshalOptions{
72		AllowPartial: true,
73	}.Marshal(m2)
74	if err != nil {
75		t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2))
76	}
77	m3 := mt.New().Interface()
78	if err := (proto.UnmarshalOptions{
79		AllowPartial: true,
80		Resolver:     test.Resolver,
81	}.Unmarshal(b, m3)); err != nil {
82		t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2))
83	}
84	if !proto.Equal(m2, m3) {
85		t.Errorf("round-trip marshal/unmarshal did not preserve message\nOriginal:\n%v\nNew:\n%v", prototext.Format(m2), prototext.Format(m3))
86	}
87}
88
89func testType(t testing.TB, mt pref.MessageType) {
90	m := mt.New().Interface()
91	want := reflect.TypeOf(m)
92	if got := reflect.TypeOf(m.ProtoReflect().Interface()); got != want {
93		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v", got, want)
94	}
95	if got := reflect.TypeOf(m.ProtoReflect().New().Interface()); got != want {
96		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().New().Interface()): %v != %v", got, want)
97	}
98	if got := reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()); got != want {
99		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()): %v != %v", got, want)
100	}
101	if mt, ok := mt.(pref.MessageFieldTypes); ok {
102		testFieldTypes(t, mt)
103	}
104}
105
106func testFieldTypes(t testing.TB, mt pref.MessageFieldTypes) {
107	descName := func(d pref.Descriptor) pref.FullName {
108		if d == nil {
109			return "<nil>"
110		}
111		return d.FullName()
112	}
113	typeName := func(mt pref.MessageType) pref.FullName {
114		if mt == nil {
115			return "<nil>"
116		}
117		return mt.Descriptor().FullName()
118	}
119	adjustExpr := func(idx int, expr string) string {
120		expr = strings.Replace(expr, "fd.", "md.Fields().Get(i).", -1)
121		expr = strings.Replace(expr, "(fd)", "(md.Fields().Get(i))", -1)
122		expr = strings.Replace(expr, "mti.", "mt.Message(i).", -1)
123		expr = strings.Replace(expr, "(i)", fmt.Sprintf("(%d)", idx), -1)
124		return expr
125	}
126	checkEnumDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.EnumDescriptor) {
127		if got != want {
128			t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
129		}
130	}
131	checkMessageDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageDescriptor) {
132		if got != want {
133			t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
134		}
135	}
136	checkMessageType := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageType) {
137		if got != want {
138			t.Errorf("type mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), typeName(got), typeName(want))
139		}
140	}
141
142	fds := mt.Descriptor().Fields()
143	m := mt.New()
144	for i := 0; i < fds.Len(); i++ {
145		fd := fds.Get(i)
146		switch {
147		case fd.IsList():
148			if fd.Enum() != nil {
149				checkEnumDesc(i,
150					"mt.Enum(i).Descriptor()", "fd.Enum()",
151					mt.Enum(i).Descriptor(), fd.Enum())
152			}
153			if fd.Message() != nil {
154				checkMessageDesc(i,
155					"mt.Message(i).Descriptor()", "fd.Message()",
156					mt.Message(i).Descriptor(), fd.Message())
157				checkMessageType(i,
158					"mt.Message(i)", "m.NewField(fd).List().NewElement().Message().Type()",
159					mt.Message(i), m.NewField(fd).List().NewElement().Message().Type())
160			}
161		case fd.IsMap():
162			mti := mt.Message(i)
163			if m := mti.New(); m != nil {
164				checkMessageDesc(i,
165					"m.Descriptor()", "fd.Message()",
166					m.Descriptor(), fd.Message())
167			}
168			if m := mti.Zero(); m != nil {
169				checkMessageDesc(i,
170					"m.Descriptor()", "fd.Message()",
171					m.Descriptor(), fd.Message())
172			}
173			checkMessageDesc(i,
174				"mti.Descriptor()", "fd.Message()",
175				mti.Descriptor(), fd.Message())
176			if mti := mti.(pref.MessageFieldTypes); mti != nil {
177				if fd.MapValue().Enum() != nil {
178					checkEnumDesc(i,
179						"mti.Enum(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Enum()",
180						mti.Enum(fd.MapValue().Index()).Descriptor(), fd.MapValue().Enum())
181				}
182				if fd.MapValue().Message() != nil {
183					checkMessageDesc(i,
184						"mti.Message(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Message()",
185						mti.Message(fd.MapValue().Index()).Descriptor(), fd.MapValue().Message())
186					checkMessageType(i,
187						"mti.Message(fd.MapValue().Index())", "m.NewField(fd).Map().NewValue().Message().Type()",
188						mti.Message(fd.MapValue().Index()), m.NewField(fd).Map().NewValue().Message().Type())
189				}
190			}
191		default:
192			if fd.Enum() != nil {
193				checkEnumDesc(i,
194					"mt.Enum(i).Descriptor()", "fd.Enum()",
195					mt.Enum(i).Descriptor(), fd.Enum())
196			}
197			if fd.Message() != nil {
198				checkMessageDesc(i,
199					"mt.Message(i).Descriptor()", "fd.Message()",
200					mt.Message(i).Descriptor(), fd.Message())
201				checkMessageType(i,
202					"mt.Message(i)", "m.NewField(fd).Message().Type()",
203					mt.Message(i), m.NewField(fd).Message().Type())
204			}
205		}
206	}
207}
208
209// testField exercises set/get/has/clear of a field.
210func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
211	name := fd.FullName()
212	num := fd.Number()
213
214	switch {
215	case fd.IsList():
216		testFieldList(t, m, fd)
217	case fd.IsMap():
218		testFieldMap(t, m, fd)
219	case fd.Message() != nil:
220	default:
221		if got, want := m.NewField(fd), fd.Default(); !valueEqual(got, want) {
222			t.Errorf("Message.NewField(%v) = %v, want default value %v", name, formatValue(got), formatValue(want))
223		}
224		if fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind {
225			testFieldFloat(t, m, fd)
226		}
227	}
228
229	// Set to a non-zero value, the zero value, different non-zero values.
230	for _, n := range []seed{1, 0, minVal, maxVal} {
231		v := newValue(m, fd, n, nil)
232		m.Set(fd, v)
233		wantHas := true
234		if n == 0 {
235			if fd.Syntax() == pref.Proto3 && fd.Message() == nil {
236				wantHas = false
237			}
238			if fd.IsExtension() {
239				wantHas = true
240			}
241			if fd.Cardinality() == pref.Repeated {
242				wantHas = false
243			}
244			if fd.ContainingOneof() != nil {
245				wantHas = true
246			}
247		}
248		if fd.Syntax() == pref.Proto3 && fd.Cardinality() != pref.Repeated && fd.ContainingOneof() == nil && fd.Kind() == pref.EnumKind && v.Enum() == 0 {
249			wantHas = false
250		}
251		if got, want := m.Has(fd), wantHas; got != want {
252			t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want)
253		}
254		if got, want := m.Get(fd), v; !valueEqual(got, want) {
255			t.Errorf("after setting %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
256		}
257		found := false
258		m.Range(func(d pref.FieldDescriptor, got pref.Value) bool {
259			if fd != d {
260				return true
261			}
262			found = true
263			if want := v; !valueEqual(got, want) {
264				t.Errorf("after setting %q:\nMessage.Range got value %v, want %v", name, formatValue(got), formatValue(want))
265			}
266			return true
267		})
268		if got, want := wantHas, found; got != want {
269			t.Errorf("after setting %q:\nMessageRange saw field: %v, want %v", name, got, want)
270		}
271	}
272
273	m.Clear(fd)
274	if got, want := m.Has(fd), false; got != want {
275		t.Errorf("after clearing %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
276	}
277	switch {
278	case fd.IsList():
279		if got := m.Get(fd); got.List().Len() != 0 {
280			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
281		}
282	case fd.IsMap():
283		if got := m.Get(fd); got.Map().Len() != 0 {
284			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
285		}
286	case fd.Message() == nil:
287		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
288			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
289		}
290	}
291
292	// Set to the default value.
293	switch {
294	case fd.IsList() || fd.IsMap():
295		m.Set(fd, m.Mutable(fd))
296		if got, want := m.Has(fd), (fd.IsExtension() && fd.Cardinality() != pref.Repeated) || fd.ContainingOneof() != nil; got != want {
297			t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
298		}
299	case fd.Message() == nil:
300		m.Set(fd, m.Get(fd))
301		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
302			t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
303		}
304	}
305	m.Clear(fd)
306
307	// Set to the wrong type.
308	v := pref.ValueOfString("")
309	if fd.Kind() == pref.StringKind {
310		v = pref.ValueOfInt32(0)
311	}
312	if !panics(func() {
313		m.Set(fd, v)
314	}) {
315		t.Errorf("setting %v to %T succeeds, want panic", name, v.Interface())
316	}
317}
318
// testFieldMap tests set/get/has/clear of entries in a map field.
//
// Starting from a cleared field, it checks two things:
//   - The map returned by Get is invalid and rejects writes.
//   - The map returned by Mutable is valid.
//
// It then inserts, overwrites and removes entries through the mutable map,
// comparing m.Get(fd) with an expected testMap after every step. Finally it
// checks lookups of a missing key and Map.Mutable, which must panic unless the
// map values are messages.
func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
	name := fd.FullName()
	num := fd.Number()

	// New values.
	m.Clear(fd) // start with an empty map
	mapv := m.Get(fd).Map()
	// The map of an unset field must be invalid but still able to create new
	// values; storing it into the message or writing to it must panic.
	if mapv.IsValid() {
		t.Errorf("after clearing field: message.Get(%v).IsValid() = true, want false", name)
	}
	if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
		t.Errorf("message.Get(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
	}
	if !panics(func() {
		m.Set(fd, pref.ValueOfMap(mapv))
	}) {
		t.Errorf("message.Set(%v, <invalid>) does not panic", name)
	}
	if !panics(func() {
		mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, nil))
	}) {
		t.Errorf("message.Get(%v).Set(...) of invalid map does not panic", name)
	}
	mapv = m.Mutable(fd).Map() // mutable map
	if !mapv.IsValid() {
		t.Errorf("message.Mutable(%v).IsValid() = false, want true", name)
	}
	if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
		t.Errorf("message.Mutable(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
	}

	// Add values.
	// Keys from different seeds may coincide (e.g. bool keys); want mirrors
	// any such overwrite. Has must be false only before the first insertion.
	want := make(testMap)
	for i, n := range []seed{1, 0, minVal, maxVal} {
		if got, want := m.Has(fd), i > 0; got != want {
			t.Errorf("after inserting %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
		}

		k := newMapKey(fd, n)
		v := newMapValue(fd, mapv, n, nil)
		mapv.Set(k, v)
		want.Set(k, v)
		if got, want := m.Get(fd), pref.ValueOfMap(want); !valueEqual(got, want) {
			t.Errorf("after inserting %d elements to %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
		}
	}

	// Set values.
	// Overwrite every existing entry with the seed-10 value; updating existing
	// keys while ranging over a Go map is permitted.
	want.Range(func(k pref.MapKey, v pref.Value) bool {
		nv := newMapValue(fd, mapv, 10, nil)
		mapv.Set(k, nv)
		want.Set(k, nv)
		if got, want := m.Get(fd), pref.ValueOfMap(want); !valueEqual(got, want) {
			t.Errorf("after setting element %v of %q:\nMessage.Get(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
		}
		return true
	})

	// Clear values.
	// Deleting entries while ranging over a Go map is permitted; Has must stay
	// true until the last entry is removed.
	want.Range(func(k pref.MapKey, v pref.Value) bool {
		mapv.Clear(k)
		want.Clear(k)
		if got, want := m.Has(fd), want.Len() > 0; got != want {
			t.Errorf("after clearing elements of %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
		}
		if got, want := m.Get(fd), pref.ValueOfMap(want); !valueEqual(got, want) {
			t.Errorf("after clearing elements of %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
		}
		return true
	})
	// With every entry removed, Get must again return an invalid map. This
	// mapv shadows the mutable map only inside the if; the checks below keep
	// using the mutable map.
	if mapv := m.Get(fd).Map(); mapv.IsValid() {
		t.Errorf("after clearing all elements: message.Get(%v).IsValid() = true, want false %v", name, formatValue(pref.ValueOfMap(mapv)))
	}

	// Non-existent map keys.
	// The seed-1 key was inserted above and has since been removed.
	missingKey := newMapKey(fd, 1)
	if got, want := mapv.Has(missingKey), false; got != want {
		t.Errorf("non-existent map key in %q: Map.Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
	}
	if got, want := mapv.Get(missingKey).IsValid(), false; got != want {
		t.Errorf("non-existent map key in %q: Map.Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
	}
	mapv.Clear(missingKey) // noop

	// Mutable.
	// Map.Mutable must panic for scalar values. For message values it must add
	// an entry whose later modifications are visible through Map.Get.
	if fd.MapValue().Message() == nil {
		if !panics(func() {
			mapv.Mutable(newMapKey(fd, 1))
		}) {
			t.Errorf("Mutable on %q succeeds, want panic", name)
		}
	} else {
		k := newMapKey(fd, 1)
		v := mapv.Mutable(k)
		if got, want := mapv.Len(), 1; got != want {
			t.Errorf("after Mutable on %q, Map.Len() = %v, want %v", name, got, want)
		}
		populateMessage(v.Message(), 1, nil)
		if !valueEqual(mapv.Get(k), v) {
			t.Errorf("after Mutable on %q, changing new mutable value does not change map entry", name)
		}
		mapv.Clear(k)
	}
}
424
425type testMap map[interface{}]pref.Value
426
427func (m testMap) Get(k pref.MapKey) pref.Value     { return m[k.Interface()] }
428func (m testMap) Set(k pref.MapKey, v pref.Value)  { m[k.Interface()] = v }
429func (m testMap) Has(k pref.MapKey) bool           { return m.Get(k).IsValid() }
430func (m testMap) Clear(k pref.MapKey)              { delete(m, k.Interface()) }
431func (m testMap) Mutable(k pref.MapKey) pref.Value { panic("unimplemented") }
432func (m testMap) Len() int                         { return len(m) }
433func (m testMap) NewValue() pref.Value             { panic("unimplemented") }
434func (m testMap) Range(f func(pref.MapKey, pref.Value) bool) {
435	for k, v := range m {
436		if !f(pref.ValueOf(k).MapKey(), v) {
437			return
438		}
439	}
440}
441func (m testMap) IsValid() bool { return true }
442
443// testFieldList exercises set/get/append/truncate of values in a list.
444func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
445	name := fd.FullName()
446	num := fd.Number()
447
448	m.Clear(fd) // start with an empty list
449	list := m.Get(fd).List()
450	if list.IsValid() {
451		t.Errorf("message.Get(%v).IsValid() = true, want false", name)
452	}
453	if !panics(func() {
454		m.Set(fd, pref.ValueOfList(list))
455	}) {
456		t.Errorf("message.Set(%v, <invalid>) does not panic", name)
457	}
458	if !panics(func() {
459		list.Append(newListElement(fd, list, 0, nil))
460	}) {
461		t.Errorf("message.Get(%v).Append(...) of invalid list does not panic", name)
462	}
463	if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
464		t.Errorf("message.Get(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
465	}
466	list = m.Mutable(fd).List() // mutable list
467	if !list.IsValid() {
468		t.Errorf("message.Get(%v).IsValid() = false, want true", name)
469	}
470	if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
471		t.Errorf("message.Mutable(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
472	}
473
474	// Append values.
475	var want pref.List = &testList{}
476	for i, n := range []seed{1, 0, minVal, maxVal} {
477		if got, want := m.Has(fd), i > 0; got != want {
478			t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
479		}
480		v := newListElement(fd, list, n, nil)
481		want.Append(v)
482		list.Append(v)
483
484		if got, want := m.Get(fd), pref.ValueOfList(want); !valueEqual(got, want) {
485			t.Errorf("after appending %d elements to %q:\nMessage.Get(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
486		}
487	}
488
489	// Set values.
490	for i := 0; i < want.Len(); i++ {
491		v := newListElement(fd, list, seed(i+10), nil)
492		want.Set(i, v)
493		list.Set(i, v)
494		if got, want := m.Get(fd), pref.ValueOfList(want); !valueEqual(got, want) {
495			t.Errorf("after setting element %d of %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
496		}
497	}
498
499	// Truncate.
500	for want.Len() > 0 {
501		n := want.Len() - 1
502		want.Truncate(n)
503		list.Truncate(n)
504		if got, want := m.Has(fd), want.Len() > 0; got != want {
505			t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want)
506		}
507		if got, want := m.Get(fd), pref.ValueOfList(want); !valueEqual(got, want) {
508			t.Errorf("after truncating %q to %d:\nMessage.Get(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
509		}
510	}
511
512	// AppendMutable.
513	if fd.Message() == nil {
514		if !panics(func() {
515			list.AppendMutable()
516		}) {
517			t.Errorf("AppendMutable on %q succeeds, want panic", name)
518		}
519	} else {
520		v := list.AppendMutable()
521		if got, want := list.Len(), 1; got != want {
522			t.Errorf("after AppendMutable on %q, list.Len() = %v, want %v", name, got, want)
523		}
524		populateMessage(v.Message(), 1, nil)
525		if !valueEqual(list.Get(0), v) {
526			t.Errorf("after AppendMutable on %q, changing new mutable value does not change list item 0", name)
527		}
528		want.Truncate(0)
529	}
530}
531
532type testList struct {
533	a []pref.Value
534}
535
536func (l *testList) Append(v pref.Value)       { l.a = append(l.a, v) }
537func (l *testList) AppendMutable() pref.Value { panic("unimplemented") }
538func (l *testList) Get(n int) pref.Value      { return l.a[n] }
539func (l *testList) Len() int                  { return len(l.a) }
540func (l *testList) Set(n int, v pref.Value)   { l.a[n] = v }
541func (l *testList) Truncate(n int)            { l.a = l.a[:n] }
542func (l *testList) NewElement() pref.Value    { panic("unimplemented") }
543func (l *testList) IsValid() bool             { return true }
544
545// testFieldFloat exercises some interesting floating-point scalar field values.
546func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
547	name := fd.FullName()
548	num := fd.Number()
549
550	for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} {
551		var val pref.Value
552		if fd.Kind() == pref.FloatKind {
553			val = pref.ValueOfFloat32(float32(v))
554		} else {
555			val = pref.ValueOfFloat64(float64(v))
556		}
557		m.Set(fd, val)
558		// Note that Has is true for -0.
559		if got, want := m.Has(fd), true; got != want {
560			t.Errorf("after setting %v to %v: Message.Has(%v) = %v, want %v", name, v, num, got, want)
561		}
562		if got, want := m.Get(fd), val; !valueEqual(got, want) {
563			t.Errorf("after setting %v: Message.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
564		}
565	}
566}
567
568// testOneof tests the behavior of fields in a oneof.
569func testOneof(t testing.TB, m pref.Message, od pref.OneofDescriptor) {
570	for _, mutable := range []bool{false, true} {
571		for i := 0; i < od.Fields().Len(); i++ {
572			fda := od.Fields().Get(i)
573			if mutable {
574				// Set fields by requesting a mutable reference.
575				if !fda.IsMap() && !fda.IsList() && fda.Message() == nil {
576					continue
577				}
578				_ = m.Mutable(fda)
579			} else {
580				// Set fields explicitly.
581				m.Set(fda, newValue(m, fda, 1, nil))
582			}
583			if got, want := m.WhichOneof(od), fda; got != want {
584				t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want)
585			}
586			for j := 0; j < od.Fields().Len(); j++ {
587				fdb := od.Fields().Get(j)
588				if got, want := m.Has(fdb), i == j; got != want {
589					t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
590				}
591			}
592		}
593	}
594}
595
596// testUnknown tests the behavior of unknown fields.
597func testUnknown(t testing.TB, m pref.Message) {
598	var b []byte
599	b = protowire.AppendTag(b, 1000, protowire.VarintType)
600	b = protowire.AppendVarint(b, 1001)
601	m.SetUnknown(pref.RawFields(b))
602	if got, want := []byte(m.GetUnknown()), b; !bytes.Equal(got, want) {
603		t.Errorf("after setting unknown fields:\nGetUnknown() = %v, want %v", got, want)
604	}
605}
606
607func formatValue(v pref.Value) string {
608	switch v := v.Interface().(type) {
609	case pref.List:
610		var buf bytes.Buffer
611		buf.WriteString("list[")
612		for i := 0; i < v.Len(); i++ {
613			if i > 0 {
614				buf.WriteString(" ")
615			}
616			buf.WriteString(formatValue(v.Get(i)))
617		}
618		buf.WriteString("]")
619		return buf.String()
620	case pref.Map:
621		var buf bytes.Buffer
622		buf.WriteString("map[")
623		var keys []pref.MapKey
624		v.Range(func(k pref.MapKey, v pref.Value) bool {
625			keys = append(keys, k)
626			return true
627		})
628		sort.Slice(keys, func(i, j int) bool {
629			return keys[i].String() < keys[j].String()
630		})
631		for i, k := range keys {
632			if i > 0 {
633				buf.WriteString(" ")
634			}
635			buf.WriteString(formatValue(k.Value()))
636			buf.WriteString(":")
637			buf.WriteString(formatValue(v.Get(k)))
638		}
639		buf.WriteString("]")
640		return buf.String()
641	case pref.Message:
642		b, err := prototext.Marshal(v.Interface())
643		if err != nil {
644			return fmt.Sprintf("<%v>", err)
645		}
646		return fmt.Sprintf("%v{%v}", v.Descriptor().FullName(), string(b))
647	case string:
648		return fmt.Sprintf("%q", v)
649	default:
650		return fmt.Sprint(v)
651	}
652}
653
654func valueEqual(a, b pref.Value) bool {
655	ai, bi := a.Interface(), b.Interface()
656	switch ai.(type) {
657	case pref.Message:
658		return proto.Equal(
659			a.Message().Interface(),
660			b.Message().Interface(),
661		)
662	case pref.List:
663		lista, listb := a.List(), b.List()
664		if lista.Len() != listb.Len() {
665			return false
666		}
667		for i := 0; i < lista.Len(); i++ {
668			if !valueEqual(lista.Get(i), listb.Get(i)) {
669				return false
670			}
671		}
672		return true
673	case pref.Map:
674		mapa, mapb := a.Map(), b.Map()
675		if mapa.Len() != mapb.Len() {
676			return false
677		}
678		equal := true
679		mapa.Range(func(k pref.MapKey, v pref.Value) bool {
680			if !valueEqual(v, mapb.Get(k)) {
681				equal = false
682				return false
683			}
684			return true
685		})
686		return equal
687	case []byte:
688		return bytes.Equal(a.Bytes(), b.Bytes())
689	case float32:
690		// NaNs are equal, but must be the same NaN.
691		return math.Float32bits(ai.(float32)) == math.Float32bits(bi.(float32))
692	case float64:
693		// NaNs are equal, but must be the same NaN.
694		return math.Float64bits(ai.(float64)) == math.Float64bits(bi.(float64))
695	default:
696		return ai == bi
697	}
698}
699
// A seed is used to vary the content of a value.
//
// A seed of 0 is the zero value. Messages do not have a zero value; a
// 0-seeded message is unpopulated.
//
// The seeds minVal and maxVal select special boundary values of the value
// type (for example math.MinInt32 and math.MaxInt32 for int32 fields); the
// exact value chosen for each kind is determined by newScalarValue.
type seed int

const (
	minVal seed = -1 // selects the low boundary value of a kind
	maxVal seed = -2 // selects the high boundary value of a kind
)

// newSeed creates new seed values from a base, for example to create seeds for
// the elements in a list. Each adjustment a is folded in as n = 10*n + a, so
// newSeed(1, 2) == 12. If the input seed is minVal or maxVal, it is returned
// unchanged.
func newSeed(n seed, adjust ...int) seed {
	switch n {
	case minVal, maxVal:
		return n
	}
	for _, a := range adjust {
		n = 10*n + seed(a)
	}
	return n
}
725
726// newValue returns a new value assignable to a field.
727//
728// The stack parameter is used to avoid infinite recursion when populating circular
729// data structures.
730func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
731	switch {
732	case fd.IsList():
733		if n == 0 {
734			return m.New().Mutable(fd)
735		}
736		list := m.NewField(fd).List()
737		list.Append(newListElement(fd, list, 0, stack))
738		list.Append(newListElement(fd, list, minVal, stack))
739		list.Append(newListElement(fd, list, maxVal, stack))
740		list.Append(newListElement(fd, list, n, stack))
741		return pref.ValueOfList(list)
742	case fd.IsMap():
743		if n == 0 {
744			return m.New().Mutable(fd)
745		}
746		mapv := m.NewField(fd).Map()
747		mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
748		mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
749		mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
750		mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
751		return pref.ValueOfMap(mapv)
752	case fd.Message() != nil:
753		return populateMessage(m.NewField(fd).Message(), n, stack)
754	default:
755		return newScalarValue(fd, n)
756	}
757}
758
759func newListElement(fd pref.FieldDescriptor, list pref.List, n seed, stack []pref.MessageDescriptor) pref.Value {
760	if fd.Message() == nil {
761		return newScalarValue(fd, n)
762	}
763	return populateMessage(list.NewElement().Message(), n, stack)
764}
765
766func newMapKey(fd pref.FieldDescriptor, n seed) pref.MapKey {
767	kd := fd.MapKey()
768	return newScalarValue(kd, n).MapKey()
769}
770
771func newMapValue(fd pref.FieldDescriptor, mapv pref.Map, n seed, stack []pref.MessageDescriptor) pref.Value {
772	vd := fd.MapValue()
773	if vd.Message() == nil {
774		return newScalarValue(vd, n)
775	}
776	return populateMessage(mapv.NewValue().Message(), n, stack)
777}
778
779func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
780	switch fd.Kind() {
781	case pref.BoolKind:
782		return pref.ValueOfBool(n != 0)
783	case pref.EnumKind:
784		vals := fd.Enum().Values()
785		var i int
786		switch n {
787		case minVal:
788			i = 0
789		case maxVal:
790			i = vals.Len() - 1
791		default:
792			i = int(n) % vals.Len()
793		}
794		return pref.ValueOfEnum(vals.Get(i).Number())
795	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
796		switch n {
797		case minVal:
798			return pref.ValueOfInt32(math.MinInt32)
799		case maxVal:
800			return pref.ValueOfInt32(math.MaxInt32)
801		default:
802			return pref.ValueOfInt32(int32(n))
803		}
804	case pref.Uint32Kind, pref.Fixed32Kind:
805		switch n {
806		case minVal:
807			// Only use 0 for the zero value.
808			return pref.ValueOfUint32(1)
809		case maxVal:
810			return pref.ValueOfUint32(math.MaxInt32)
811		default:
812			return pref.ValueOfUint32(uint32(n))
813		}
814	case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
815		switch n {
816		case minVal:
817			return pref.ValueOfInt64(math.MinInt64)
818		case maxVal:
819			return pref.ValueOfInt64(math.MaxInt64)
820		default:
821			return pref.ValueOfInt64(int64(n))
822		}
823	case pref.Uint64Kind, pref.Fixed64Kind:
824		switch n {
825		case minVal:
826			// Only use 0 for the zero value.
827			return pref.ValueOfUint64(1)
828		case maxVal:
829			return pref.ValueOfUint64(math.MaxInt64)
830		default:
831			return pref.ValueOfUint64(uint64(n))
832		}
833	case pref.FloatKind:
834		switch n {
835		case minVal:
836			return pref.ValueOfFloat32(math.SmallestNonzeroFloat32)
837		case maxVal:
838			return pref.ValueOfFloat32(math.MaxFloat32)
839		default:
840			return pref.ValueOfFloat32(1.5 * float32(n))
841		}
842	case pref.DoubleKind:
843		switch n {
844		case minVal:
845			return pref.ValueOfFloat64(math.SmallestNonzeroFloat64)
846		case maxVal:
847			return pref.ValueOfFloat64(math.MaxFloat64)
848		default:
849			return pref.ValueOfFloat64(1.5 * float64(n))
850		}
851	case pref.StringKind:
852		if n == 0 {
853			return pref.ValueOfString("")
854		}
855		return pref.ValueOfString(fmt.Sprintf("%d", n))
856	case pref.BytesKind:
857		if n == 0 {
858			return pref.ValueOfBytes(nil)
859		}
860		return pref.ValueOfBytes([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
861	}
862	panic("unhandled kind")
863}
864
865func populateMessage(m pref.Message, n seed, stack []pref.MessageDescriptor) pref.Value {
866	if n == 0 {
867		return pref.ValueOfMessage(m)
868	}
869	md := m.Descriptor()
870	for _, x := range stack {
871		if md == x {
872			return pref.ValueOfMessage(m)
873		}
874	}
875	stack = append(stack, md)
876	for i := 0; i < md.Fields().Len(); i++ {
877		fd := md.Fields().Get(i)
878		if fd.IsWeak() {
879			continue
880		}
881		m.Set(fd, newValue(m, fd, newSeed(n, i), stack))
882	}
883	return pref.ValueOfMessage(m)
884}
885
// panics reports whether calling f panics.
//
// The result is preset to true and cleared only when f returns normally. This
// way every panic is detected, including panic(nil), for which recover returns
// nil on Go versions before 1.21 (or with GODEBUG=panicnil=1).
func panics(f func()) (didPanic bool) {
	didPanic = true
	defer func() {
		// Stop the panic, if any; the named result already records it.
		_ = recover()
	}()
	f()
	return false
}
895