1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package protorange
6
7import (
8	"testing"
9	"time"
10
11	"github.com/google/go-cmp/cmp"
12	"github.com/google/go-cmp/cmp/cmpopts"
13	"google.golang.org/protobuf/proto"
14	"google.golang.org/protobuf/reflect/protopath"
15	"google.golang.org/protobuf/reflect/protoreflect"
16	"google.golang.org/protobuf/reflect/protoregistry"
17	"google.golang.org/protobuf/testing/protocmp"
18
19	newspb "google.golang.org/protobuf/internal/testprotos/news"
20	anypb "google.golang.org/protobuf/types/known/anypb"
21	timestamppb "google.golang.org/protobuf/types/known/timestamppb"
22)
23
24func mustMarshal(m proto.Message) []byte {
25	b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
26	if err != nil {
27		panic(err)
28	}
29	return b
30}
31
32var transformReflectValue = cmp.Transformer("", func(v protoreflect.Value) interface{} {
33	switch v := v.Interface().(type) {
34	case protoreflect.Message:
35		return v.Interface()
36	case protoreflect.Map:
37		ms := map[interface{}]protoreflect.Value{}
38		v.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
39			ms[k.Interface()] = v
40			return true
41		})
42		return ms
43	case protoreflect.List:
44		ls := []protoreflect.Value{}
45		for i := 0; i < v.Len(); i++ {
46			ls = append(ls, v.Get(i))
47		}
48		return ls
49	default:
50		return v
51	}
52})
53
54func TestRange(t *testing.T) {
55	m2 := (&newspb.KeyValueAttachment{
56		Name: "checksums.txt",
57		Data: map[string]string{
58			"go1.10.src.tar.gz":         "07cbb9d0091b846c6aea40bf5bc0cea7",
59			"go1.10.darwin-amd64.pkg":   "cbb38bb6ff6ea86279e01745984445bf",
60			"go1.10.linux-amd64.tar.gz": "6b3d0e4a5c77352cf4275573817f7566",
61			"go1.10.windows-amd64.msi":  "57bda02030f58f5d2bf71943e1390123",
62		},
63	}).ProtoReflect()
64	m := (&newspb.Article{
65		Author:  "Brad Fitzpatrick",
66		Date:    timestamppb.New(time.Date(2018, time.February, 16, 0, 0, 0, 0, time.UTC)),
67		Title:   "Go 1.10 is released",
68		Content: "Happy Friday, happy weekend! Today the Go team is happy to announce the release of Go 1.10...",
69		Status:  newspb.Article_PUBLISHED,
70		Tags:    []string{"go1.10", "release"},
71		Attachments: []*anypb.Any{{
72			TypeUrl: "google.golang.org.KeyValueAttachment",
73			Value:   mustMarshal(m2.Interface()),
74		}},
75	}).ProtoReflect()
76
77	// Nil push and pop functions should not panic.
78	noop := func(protopath.Values) error { return nil }
79	Options{}.Range(m, nil, nil)
80	Options{}.Range(m, noop, nil)
81	Options{}.Range(m, nil, noop)
82
83	getByName := func(m protoreflect.Message, s protoreflect.Name) protoreflect.Value {
84		fds := m.Descriptor().Fields()
85		return m.Get(fds.ByName(s))
86	}
87
88	wantPaths := []string{
89		``,
90		`.author`,
91		`.date`,
92		`.date.seconds`,
93		`.title`,
94		`.content`,
95		`.attachments`,
96		`.attachments[0]`,
97		`.attachments[0].(google.golang.org.KeyValueAttachment)`,
98		`.attachments[0].(google.golang.org.KeyValueAttachment).name`,
99		`.attachments[0].(google.golang.org.KeyValueAttachment).data`,
100		`.attachments[0].(google.golang.org.KeyValueAttachment).data["go1.10.darwin-amd64.pkg"]`,
101		`.attachments[0].(google.golang.org.KeyValueAttachment).data["go1.10.linux-amd64.tar.gz"]`,
102		`.attachments[0].(google.golang.org.KeyValueAttachment).data["go1.10.src.tar.gz"]`,
103		`.attachments[0].(google.golang.org.KeyValueAttachment).data["go1.10.windows-amd64.msi"]`,
104		`.tags`,
105		`.tags[0]`,
106		`.tags[1]`,
107		`.status`,
108	}
109	wantValues := []protoreflect.Value{
110		protoreflect.ValueOfMessage(m),
111		getByName(m, "author"),
112		getByName(m, "date"),
113		getByName(getByName(m, "date").Message(), "seconds"),
114		getByName(m, `title`),
115		getByName(m, `content`),
116		getByName(m, `attachments`),
117		getByName(m, `attachments`).List().Get(0),
118		protoreflect.ValueOfMessage(m2),
119		getByName(m2, `name`),
120		getByName(m2, `data`),
121		getByName(m2, `data`).Map().Get(protoreflect.ValueOfString("go1.10.darwin-amd64.pkg").MapKey()),
122		getByName(m2, `data`).Map().Get(protoreflect.ValueOfString("go1.10.linux-amd64.tar.gz").MapKey()),
123		getByName(m2, `data`).Map().Get(protoreflect.ValueOfString("go1.10.src.tar.gz").MapKey()),
124		getByName(m2, `data`).Map().Get(protoreflect.ValueOfString("go1.10.windows-amd64.msi").MapKey()),
125		getByName(m, `tags`),
126		getByName(m, `tags`).List().Get(0),
127		getByName(m, `tags`).List().Get(1),
128		getByName(m, `status`),
129	}
130
131	tests := []struct {
132		resolver interface {
133			protoregistry.ExtensionTypeResolver
134			protoregistry.MessageTypeResolver
135		}
136
137		errorAt     int
138		breakAt     int
139		terminateAt int
140
141		wantPaths  []string
142		wantValues []protoreflect.Value
143		wantError  error
144	}{{
145		wantPaths:  wantPaths,
146		wantValues: wantValues,
147	}, {
148		resolver: (*protoregistry.Types)(nil),
149		wantPaths: append(append(wantPaths[:8:8],
150			`.attachments[0].type_url`,
151			`.attachments[0].value`,
152		), wantPaths[15:]...),
153		wantValues: append(append(wantValues[:8:8],
154			protoreflect.ValueOfString("google.golang.org.KeyValueAttachment"),
155			protoreflect.ValueOfBytes(mustMarshal(m2.Interface())),
156		), wantValues[15:]...),
157	}, {
158		errorAt:    5, // return error within newspb.Article
159		wantPaths:  wantPaths[:5],
160		wantValues: wantValues[:5],
161		wantError:  cmpopts.AnyError,
162	}, {
163		terminateAt: 11, // terminate within newspb.KeyValueAttachment
164		wantPaths:   wantPaths[:11],
165		wantValues:  wantValues[:11],
166	}, {
167		breakAt:    11, // break within newspb.KeyValueAttachment
168		wantPaths:  append(wantPaths[:11:11], wantPaths[15:]...),
169		wantValues: append(wantValues[:11:11], wantValues[15:]...),
170	}, {
171		errorAt:    17, // return error within newspb.Article.Tags
172		wantPaths:  wantPaths[:17],
173		wantValues: wantValues[:17],
174		wantError:  cmpopts.AnyError,
175	}, {
176		breakAt:    17, // break within newspb.Article.Tags
177		wantPaths:  append(wantPaths[:17:17], wantPaths[18:]...),
178		wantValues: append(wantValues[:17:17], wantValues[18:]...),
179	}, {
180		terminateAt: 17, // terminate within newspb.Article.Tags
181		wantPaths:   wantPaths[:17],
182		wantValues:  wantValues[:17],
183	}, {
184		errorAt:    13, // return error within newspb.KeyValueAttachment.Data
185		wantPaths:  wantPaths[:13],
186		wantValues: wantValues[:13],
187		wantError:  cmpopts.AnyError,
188	}, {
189		breakAt:    13, // break within newspb.KeyValueAttachment.Data
190		wantPaths:  append(wantPaths[:13:13], wantPaths[15:]...),
191		wantValues: append(wantValues[:13:13], wantValues[15:]...),
192	}, {
193		terminateAt: 13, // terminate within newspb.KeyValueAttachment.Data
194		wantPaths:   wantPaths[:13],
195		wantValues:  wantValues[:13],
196	}}
197	for _, tt := range tests {
198		t.Run("", func(t *testing.T) {
199			var gotPaths []string
200			var gotValues []protoreflect.Value
201			var stackPaths []string
202			var stackValues []protoreflect.Value
203			gotError := Options{
204				Stable:   true,
205				Resolver: tt.resolver,
206			}.Range(m,
207				func(p protopath.Values) error {
208					gotPaths = append(gotPaths, p.Path[1:].String())
209					stackPaths = append(stackPaths, p.Path[1:].String())
210					gotValues = append(gotValues, p.Index(-1).Value)
211					stackValues = append(stackValues, p.Index(-1).Value)
212					switch {
213					case tt.errorAt > 0 && tt.errorAt == len(gotPaths):
214						return cmpopts.AnyError
215					case tt.breakAt > 0 && tt.breakAt == len(gotPaths):
216						return Break
217					case tt.terminateAt > 0 && tt.terminateAt == len(gotPaths):
218						return Terminate
219					default:
220						return nil
221					}
222				},
223				func(p protopath.Values) error {
224					gotPath := p.Path[1:].String()
225					wantPath := stackPaths[len(stackPaths)-1]
226					if wantPath != gotPath {
227						t.Errorf("%d: pop path mismatch: got %v, want %v", len(gotPaths), gotPath, wantPath)
228					}
229					gotValue := p.Index(-1).Value
230					wantValue := stackValues[len(stackValues)-1]
231					if diff := cmp.Diff(wantValue, gotValue, transformReflectValue, protocmp.Transform()); diff != "" {
232						t.Errorf("%d: pop value mismatch (-want +got):\n%v", len(gotValues), diff)
233					}
234					stackPaths = stackPaths[:len(stackPaths)-1]
235					stackValues = stackValues[:len(stackValues)-1]
236					return nil
237				},
238			)
239			if n := len(stackPaths) + len(stackValues); n > 0 {
240				t.Errorf("stack mismatch: got %d unpopped items", n)
241			}
242			if diff := cmp.Diff(tt.wantPaths, gotPaths); diff != "" {
243				t.Errorf("paths mismatch (-want +got):\n%s", diff)
244			}
245			if diff := cmp.Diff(tt.wantValues, gotValues, transformReflectValue, protocmp.Transform()); diff != "" {
246				t.Errorf("values mismatch (-want +got):\n%s", diff)
247			}
248			if !cmp.Equal(gotError, tt.wantError, cmpopts.EquateErrors()) {
249				t.Errorf("error mismatch: got %v, want %v", gotError, tt.wantError)
250			}
251		})
252	}
253}
254