1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package firestore
16
17import (
18	"context"
19	"testing"
20
21	tspb "github.com/golang/protobuf/ptypes/timestamp"
22	pb "google.golang.org/genproto/googleapis/firestore/v1"
23	"google.golang.org/grpc/codes"
24	"google.golang.org/grpc/status"
25)
26
27var testClient = &Client{
28	projectID:  "projectID",
29	databaseID: "(default)",
30}
31
32func TestClientCollectionAndDoc(t *testing.T) {
33	coll1 := testClient.Collection("X")
34	db := "projects/projectID/databases/(default)"
35	wantc1 := &CollectionRef{
36		c:          testClient,
37		parentPath: db + "/documents",
38		selfPath:   "X",
39		Parent:     nil,
40		ID:         "X",
41		Path:       "projects/projectID/databases/(default)/documents/X",
42		Query: Query{
43			c:            testClient,
44			collectionID: "X",
45			path:         "projects/projectID/databases/(default)/documents/X",
46			parentPath:   db + "/documents",
47		},
48	}
49	if !testEqual(coll1, wantc1) {
50		t.Fatalf("got\n%+v\nwant\n%+v", coll1, wantc1)
51	}
52	doc1 := testClient.Doc("X/a")
53	wantd1 := &DocumentRef{
54		Parent:    coll1,
55		ID:        "a",
56		Path:      "projects/projectID/databases/(default)/documents/X/a",
57		shortPath: "X/a",
58	}
59
60	if !testEqual(doc1, wantd1) {
61		t.Fatalf("got %+v, want %+v", doc1, wantd1)
62	}
63	coll2 := testClient.Collection("X/a/Y")
64	parentPath := "projects/projectID/databases/(default)/documents/X/a"
65	wantc2 := &CollectionRef{
66		c:          testClient,
67		parentPath: parentPath,
68		selfPath:   "X/a/Y",
69		Parent:     doc1,
70		ID:         "Y",
71		Path:       "projects/projectID/databases/(default)/documents/X/a/Y",
72		Query: Query{
73			c:            testClient,
74			collectionID: "Y",
75			parentPath:   parentPath,
76			path:         "projects/projectID/databases/(default)/documents/X/a/Y",
77		},
78	}
79	if !testEqual(coll2, wantc2) {
80		t.Fatalf("\ngot  %+v\nwant %+v", coll2, wantc2)
81	}
82	doc2 := testClient.Doc("X/a/Y/b")
83	wantd2 := &DocumentRef{
84		Parent:    coll2,
85		ID:        "b",
86		Path:      "projects/projectID/databases/(default)/documents/X/a/Y/b",
87		shortPath: "X/a/Y/b",
88	}
89	if !testEqual(doc2, wantd2) {
90		t.Fatalf("got %+v, want %+v", doc2, wantd2)
91	}
92}
93
// TestClientCollDocErrors checks that Client.Collection and Client.Doc
// return nil for malformed paths: empty paths, leading or trailing slashes,
// empty segments, and paths with the wrong number of segments for the kind
// of ref requested.
func TestClientCollDocErrors(t *testing.T) {
	invalidCollPaths := []string{"", "/", "/a/", "/a/b", "a/b/", "a//b"}
	for _, p := range invalidCollPaths {
		if got := testClient.Collection(p); got != nil {
			t.Errorf("coll path %q: got %+v, want nil", p, got)
		}
	}

	invalidDocPaths := []string{"", "a", "/", "/a", "a/", "a/b/c", "a//b/c"}
	for _, p := range invalidDocPaths {
		if got := testClient.Doc(p); got != nil {
			t.Errorf("doc path %q: got %+v, want nil", p, got)
		}
	}
}
108
// TestGetAll runs the shared testGetAll scenario through Client.GetAll. It
// expects a single BatchGetDocuments request for C/a, C/b and C/c.
func TestGetAll(t *testing.T) {
	c, srv, cleanup := newMock(t)
	defer cleanup()

	const dbPath = "projects/projectID/databases/(default)"
	var docNames []string
	for _, id := range []string{"a", "b", "c"} {
		docNames = append(docNames, dbPath+"/documents/C/"+id)
	}
	req := &pb.BatchGetDocumentsRequest{
		Database:  dbPath,
		Documents: docNames,
	}

	viaClient := func(drs []*DocumentRef) ([]*DocumentSnapshot, error) {
		return c.GetAll(context.Background(), drs)
	}
	testGetAll(t, c, srv, dbPath, viaClient, req)
}
126
// testGetAll is the shared body of GetAll-style tests.
//
// It registers a canned stream of BatchGetDocuments responses for req on srv:
// C/a and C/c are found and C/b is missing, streamed deliberately out of
// request order. It then calls getAll with refs to C/a, C/b and C/c (in that
// order). Each returned snapshot must match, in input order, the expected
// document (nil for the missing one) and the read time of its own response.
//
// getAll is the call under test; TestGetAll passes a closure over
// Client.GetAll. dbPath is the database resource name used to build the
// expected document names.
func testGetAll(t *testing.T, c *Client, srv *mockServer, dbPath string, getAll func([]*DocumentRef) ([]*DocumentSnapshot, error), req *pb.BatchGetDocumentsRequest) {
	wantPBDocs := []*pb.Document{
		{
			Name:       dbPath + "/documents/C/a",
			CreateTime: aTimestamp,
			UpdateTime: aTimestamp,
			Fields:     map[string]*pb.Value{"f": intval(2)},
		},
		nil, // C/b is reported missing by the server.
		{
			Name:       dbPath + "/documents/C/c",
			CreateTime: aTimestamp,
			UpdateTime: aTimestamp,
			Fields:     map[string]*pb.Value{"f": intval(1)},
		},
	}
	// Read time expected on each returned snapshot, indexed like wantPBDocs.
	wantReadTimes := []*tspb.Timestamp{aTimestamp, aTimestamp2, aTimestamp3}
	srv.addRPC(req,
		[]interface{}{
			// Deliberately out of order: results must come back in the order
			// of the input refs, not the order the server streams them.
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Found{Found: wantPBDocs[2]},
				ReadTime: aTimestamp3,
			},
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Found{Found: wantPBDocs[0]},
				ReadTime: aTimestamp,
			},
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Missing{Missing: dbPath + "/documents/C/b"},
				ReadTime: aTimestamp2,
			},
		},
	)
	coll := c.Collection("C")
	var docRefs []*DocumentRef
	for _, name := range []string{"a", "b", "c"} {
		docRefs = append(docRefs, coll.Doc(name))
	}
	docs, err := getAll(docRefs)
	if err != nil {
		t.Fatal(err)
	}
	if got, want := len(docs), len(wantPBDocs); got != want {
		// Stop here: the loop below indexes docRefs, wantPBDocs and
		// wantReadTimes by position in docs and would panic on a longer result.
		t.Fatalf("got %d docs, wanted %d", got, want)
	}
	for i, got := range docs {
		want, err := newDocumentSnapshot(docRefs[i], wantPBDocs[i], c, wantReadTimes[i])
		if err != nil {
			t.Fatal(err)
		}
		if diff := testDiff(got, want); diff != "" {
			t.Errorf("#%d: got=--, want==++\n%s", i, diff)
		}
	}
}
183
// TestGetAllWithEqualRefs runs the shared testGetAllWithEqualRefs scenario
// through Client.GetAll. The expected request lists every ref in input
// order, duplicates included.
func TestGetAllWithEqualRefs(t *testing.T) {
	c, srv, cleanup := newMock(t)
	defer cleanup()

	const dbPath = "projects/projectID/databases/(default)"
	ids := []string{"a", "a", "c", "a", "b", "c", "b"}
	names := make([]string, len(ids))
	for i, id := range ids {
		names[i] = dbPath + "/documents/C/" + id
	}
	req := &pb.BatchGetDocumentsRequest{Database: dbPath, Documents: names}

	viaClient := func(drs []*DocumentRef) ([]*DocumentSnapshot, error) {
		return c.GetAll(context.Background(), drs)
	}
	testGetAllWithEqualRefs(t, c, srv, dbPath, viaClient, req)
}
205
// testGetAllWithEqualRefs is the shared body of GetAll-style tests where the
// input contains repeated refs.
//
// It registers a canned BatchGetDocuments stream for req on srv with one
// response per distinct document: C/a and C/c found, C/b missing, streamed
// out of order. It then calls getAll with refs a, a, c, a, b, c, b. The
// result must have one snapshot per input ref, in input order. Repeated refs
// map to the same document and read time.
//
// getAll is the call under test; TestGetAllWithEqualRefs passes a closure
// over Client.GetAll. dbPath is the database resource name used to build the
// expected document names.
func testGetAllWithEqualRefs(t *testing.T, c *Client, srv *mockServer, dbPath string, getAll func([]*DocumentRef) ([]*DocumentSnapshot, error), req *pb.BatchGetDocumentsRequest) {
	wantPBDocs := []*pb.Document{
		{
			Name:       dbPath + "/documents/C/a",
			CreateTime: aTimestamp,
			UpdateTime: aTimestamp,
			Fields:     map[string]*pb.Value{"f": intval(2)},
		},
		{
			Name:       dbPath + "/documents/C/c",
			CreateTime: aTimestamp,
			UpdateTime: aTimestamp,
			Fields:     map[string]*pb.Value{"f": intval(1)},
		},
		nil, // C/b is reported missing by the server.
	}
	srv.addRPC(req,
		[]interface{}{
			// Deliberately out of order: results must come back in the order
			// of the input refs, not the order the server streams them.
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Found{Found: wantPBDocs[1]},
				ReadTime: aTimestamp3,
			},
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Found{Found: wantPBDocs[0]},
				ReadTime: aTimestamp,
			},
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Missing{Missing: dbPath + "/documents/C/b"},
				ReadTime: aTimestamp2,
			},
		},
	)
	coll := c.Collection("C")
	var docRefs []*DocumentRef
	for _, name := range []string{"a", "a", "c", "a", "b", "c", "b"} {
		docRefs = append(docRefs, coll.Doc(name))
	}
	// GetAll should return the same number of document snapshots as the
	// number of document references in the input range, even when that means
	// that the same document snapshot is referenced multiple times in the
	// returned collection.
	docs, err := getAll(docRefs)
	if err != nil {
		t.Fatal(err)
	}
	// For each input ref: the index into wantPBDocs and the expected read time.
	wantDocsIndices := []int{0, 0, 1, 0, 2, 1, 2}
	wantReadTimes := []*tspb.Timestamp{aTimestamp, aTimestamp, aTimestamp3, aTimestamp, aTimestamp2, aTimestamp3, aTimestamp2}
	if got, want := len(docs), len(wantDocsIndices); got != want {
		// Stop here: the loop below indexes docRefs, wantDocsIndices and
		// wantReadTimes by position in docs and would panic on a longer result.
		t.Fatalf("got %d docs, wanted %d", got, want)
	}
	for i, got := range docs {
		want, err := newDocumentSnapshot(docRefs[i], wantPBDocs[wantDocsIndices[i]], c, wantReadTimes[i])
		if err != nil {
			t.Fatal(err)
		}
		if diff := testDiff(got, want); diff != "" {
			t.Errorf("#%d: got=--, want==++\n%s", i, diff)
		}
	}
}
267
// TestGetAllErrors checks the failure paths of Client.GetAll:
//   - a nil ref in the input yields errNilDocRef;
//   - an Internal error from the server is surfaced with that code;
//   - a server stream that reports the same document as both found and
//     missing produces an error.
func TestGetAllErrors(t *testing.T) {
	ctx := context.Background()
	c, srv, cleanup := newMock(t)
	defer cleanup()

	const dbPath = "projects/projectID/databases/(default)"
	const docPath = dbPath + "/documents/C/a"
	if _, err := c.GetAll(ctx, []*DocumentRef{nil}); err != errNilDocRef {
		t.Errorf("got %v, want errNilDocRef", err)
	}

	// Internal server error.
	srv.addRPC(
		&pb.BatchGetDocumentsRequest{
			Database:  dbPath,
			Documents: []string{docPath},
		},
		[]interface{}{status.Errorf(codes.Internal, "")},
	)
	_, err := c.GetAll(ctx, []*DocumentRef{c.Doc("C/a")})
	codeEq(t, "GetAll #1", codes.Internal, err)

	// Doc appears as both found and missing (server bug).
	srv.reset()
	srv.addRPC(
		&pb.BatchGetDocumentsRequest{
			Database:  dbPath,
			Documents: []string{docPath},
		},
		[]interface{}{
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Found{Found: &pb.Document{Name: docPath}},
				ReadTime: aTimestamp,
			},
			&pb.BatchGetDocumentsResponse{
				Result:   &pb.BatchGetDocumentsResponse_Missing{Missing: docPath},
				ReadTime: aTimestamp,
			},
		},
	)
	if _, err := c.GetAll(ctx, []*DocumentRef{c.Doc("C/a")}); err == nil {
		t.Error("got nil, want error")
	}
}
312