1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package fieldmaskpb_test
6
7import (
8	"testing"
9
10	"github.com/google/go-cmp/cmp"
11	"github.com/google/go-cmp/cmp/cmpopts"
12	"google.golang.org/protobuf/proto"
13
14	testpb "google.golang.org/protobuf/internal/testprotos/test"
15	fmpb "google.golang.org/protobuf/types/known/fieldmaskpb"
16)
17
18func TestAppend(t *testing.T) {
19	tests := []struct {
20		inMessage proto.Message
21		inPaths   []string
22		wantPaths []string
23		wantError error
24	}{{
25		inMessage: (*fmpb.FieldMask)(nil),
26		inPaths:   []string{},
27		wantPaths: []string{},
28	}, {
29		inMessage: (*fmpb.FieldMask)(nil),
30		inPaths:   []string{"paths", "paths"},
31		wantPaths: []string{"paths", "paths"},
32	}, {
33		inMessage: (*fmpb.FieldMask)(nil),
34		inPaths:   []string{"paths", "<INVALID>", "paths"},
35		wantPaths: []string{"paths"},
36		wantError: cmpopts.AnyError,
37	}, {
38		inMessage: (*testpb.TestAllTypes)(nil),
39		inPaths:   []string{"optional_int32", "OptionalGroup.optional_nested_message", "map_uint32_uint32", "map_string_nested_message.corecursive", "oneof_bool"},
40		wantPaths: []string{"optional_int32", "OptionalGroup.optional_nested_message", "map_uint32_uint32", "map_string_nested_message.corecursive", "oneof_bool"},
41	}, {
42		inMessage: (*testpb.TestAllTypes)(nil),
43		inPaths:   []string{"optional_nested_message", "optional_nested_message.corecursive", "optional_nested_message.corecursive.optional_nested_message", "optional_nested_message.corecursive.optional_nested_message.corecursive"},
44		wantPaths: []string{"optional_nested_message", "optional_nested_message.corecursive", "optional_nested_message.corecursive.optional_nested_message", "optional_nested_message.corecursive.optional_nested_message.corecursive"},
45	}, {
46		inMessage: (*testpb.TestAllTypes)(nil),
47		inPaths:   []string{"optional_int32", "optional_nested_message.corecursive.optional_int64", "optional_nested_message.corecursive.<INVALID>", "optional_int64"},
48		wantPaths: []string{"optional_int32", "optional_nested_message.corecursive.optional_int64"},
49		wantError: cmpopts.AnyError,
50	}, {
51		inMessage: (*testpb.TestAllTypes)(nil),
52		inPaths:   []string{"optional_int32", "optional_nested_message.corecursive.oneof_uint32", "optional_nested_message.oneof_field", "optional_int64"},
53		wantPaths: []string{"optional_int32", "optional_nested_message.corecursive.oneof_uint32"},
54		wantError: cmpopts.AnyError,
55	}}
56
57	for _, tt := range tests {
58		t.Run("", func(t *testing.T) {
59			var mask fmpb.FieldMask
60			gotError := mask.Append(tt.inMessage, tt.inPaths...)
61			gotPaths := mask.GetPaths()
62			if diff := cmp.Diff(tt.wantPaths, gotPaths, cmpopts.EquateEmpty()); diff != "" {
63				t.Errorf("Append() paths mismatch (-want +got):\n%s", diff)
64			}
65			if diff := cmp.Diff(tt.wantError, gotError, cmpopts.EquateErrors()); diff != "" {
66				t.Errorf("Append() error mismatch (-want +got):\n%s", diff)
67			}
68		})
69	}
70}
71
72func TestCombine(t *testing.T) {
73	tests := []struct {
74		in            [][]string
75		wantUnion     []string
76		wantIntersect []string
77	}{{
78		in: [][]string{
79			{},
80			{},
81		},
82		wantUnion:     []string{},
83		wantIntersect: []string{},
84	}, {
85		in: [][]string{
86			{"a"},
87			{},
88		},
89		wantUnion:     []string{"a"},
90		wantIntersect: []string{},
91	}, {
92		in: [][]string{
93			{"a"},
94			{"a"},
95		},
96		wantUnion:     []string{"a"},
97		wantIntersect: []string{"a"},
98	}, {
99		in: [][]string{
100			{"a"},
101			{"b"},
102			{"c"},
103		},
104		wantUnion:     []string{"a", "b", "c"},
105		wantIntersect: []string{},
106	}, {
107		in: [][]string{
108			{"a", "b"},
109			{"b.b"},
110			{"b"},
111			{"b", "a.A"},
112			{"b", "c", "c.a", "c.b"},
113		},
114		wantUnion:     []string{"a", "b", "c"},
115		wantIntersect: []string{"b.b"},
116	}, {
117		in: [][]string{
118			{"a.b", "a.c.d"},
119			{"a"},
120		},
121		wantUnion:     []string{"a"},
122		wantIntersect: []string{"a.b", "a.c.d"},
123	}, {
124		in: [][]string{
125			{},
126			{"a.b", "a.c", "d"},
127		},
128		wantUnion:     []string{"a.b", "a.c", "d"},
129		wantIntersect: []string{},
130	}}
131
132	for _, tt := range tests {
133		t.Run("", func(t *testing.T) {
134			var masks []*fmpb.FieldMask
135			for _, paths := range tt.in {
136				masks = append(masks, &fmpb.FieldMask{Paths: paths})
137			}
138
139			union := fmpb.Union(masks[0], masks[1], masks[2:]...)
140			gotUnion := union.GetPaths()
141			if diff := cmp.Diff(tt.wantUnion, gotUnion, cmpopts.EquateEmpty()); diff != "" {
142				t.Errorf("Union() mismatch (-want +got):\n%s", diff)
143			}
144
145			intersect := fmpb.Intersect(masks[0], masks[1], masks[2:]...)
146			gotIntersect := intersect.GetPaths()
147			if diff := cmp.Diff(tt.wantIntersect, gotIntersect, cmpopts.EquateEmpty()); diff != "" {
148				t.Errorf("Intersect() mismatch (-want +got):\n%s", diff)
149			}
150		})
151	}
152}
153
154func TestNormalize(t *testing.T) {
155	tests := []struct {
156		in   []string
157		want []string
158	}{{
159		in:   []string{},
160		want: []string{},
161	}, {
162		in:   []string{"a"},
163		want: []string{"a"},
164	}, {
165		in:   []string{"foo", "foo.bar", "foo.baz"},
166		want: []string{"foo"},
167	}, {
168		in:   []string{"foo.bar", "foo.baz"},
169		want: []string{"foo.bar", "foo.baz"},
170	}, {
171		in:   []string{"", "a.", ".b", "a.b", ".", "", "a.", ".b", "a.b", "."},
172		want: []string{"", "a.", "a.b"},
173	}, {
174		in:   []string{"e.a", "e.b", "e.c", "e.d", "e.f", "e.g", "e.b.a", "e$c", "e.b.c"},
175		want: []string{"e.a", "e.b", "e.c", "e.d", "e.f", "e.g", "e$c"},
176	}, {
177		in:   []string{"a", "aa", "aaa", "a$", "AAA", "aA.a", "a.a", "a", "aa", "aaa", "a$", "AAA", "aA.a"},
178		want: []string{"AAA", "a", "aA.a", "aa", "aaa", "a$"},
179	}, {
180		in:   []string{"a.b", "aa.bb.cc", ".", "a$b", "aa", "a.", "a", "b.c.d", ".a", "", "a$", "a$", "a.b", "a", "a.bb", ""},
181		want: []string{"", "a", "aa", "a$", "a$b", "b.c.d"},
182	}}
183
184	for _, tt := range tests {
185		t.Run("", func(t *testing.T) {
186			mask := &fmpb.FieldMask{
187				Paths: append([]string(nil), tt.in...),
188			}
189			mask.Normalize()
190			got := mask.GetPaths()
191			if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
192				t.Errorf("Normalize() mismatch (-want +got):\n%s", diff)
193			}
194		})
195	}
196}
197