1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package firestore
16
17import (
18	"context"
19	"testing"
20
21	tspb "github.com/golang/protobuf/ptypes/timestamp"
22	pb "google.golang.org/genproto/googleapis/firestore/v1"
23	"google.golang.org/grpc/codes"
24	"google.golang.org/grpc/status"
25)
26
27var testClient = &Client{
28	projectID:  "projectID",
29	databaseID: "(default)",
30}
31
32func TestClientCollectionAndDoc(t *testing.T) {
33	coll1 := testClient.Collection("X")
34	db := "projects/projectID/databases/(default)"
35	wantc1 := &CollectionRef{
36		c:          testClient,
37		parentPath: db + "/documents",
38		selfPath:   "X",
39		Parent:     nil,
40		ID:         "X",
41		Path:       "projects/projectID/databases/(default)/documents/X",
42		Query: Query{
43			c:            testClient,
44			collectionID: "X",
45			path:         "projects/projectID/databases/(default)/documents/X",
46			parentPath:   db + "/documents",
47		},
48	}
49	if !testEqual(coll1, wantc1) {
50		t.Fatalf("got\n%+v\nwant\n%+v", coll1, wantc1)
51	}
52	doc1 := testClient.Doc("X/a")
53	wantd1 := &DocumentRef{
54		Parent:    coll1,
55		ID:        "a",
56		Path:      "projects/projectID/databases/(default)/documents/X/a",
57		shortPath: "X/a",
58	}
59
60	if !testEqual(doc1, wantd1) {
61		t.Fatalf("got %+v, want %+v", doc1, wantd1)
62	}
63	coll2 := testClient.Collection("X/a/Y")
64	parentPath := "projects/projectID/databases/(default)/documents/X/a"
65	wantc2 := &CollectionRef{
66		c:          testClient,
67		parentPath: parentPath,
68		selfPath:   "X/a/Y",
69		Parent:     doc1,
70		ID:         "Y",
71		Path:       "projects/projectID/databases/(default)/documents/X/a/Y",
72		Query: Query{
73			c:            testClient,
74			collectionID: "Y",
75			parentPath:   parentPath,
76			path:         "projects/projectID/databases/(default)/documents/X/a/Y",
77		},
78	}
79	if !testEqual(coll2, wantc2) {
80		t.Fatalf("\ngot  %+v\nwant %+v", coll2, wantc2)
81	}
82	doc2 := testClient.Doc("X/a/Y/b")
83	wantd2 := &DocumentRef{
84		Parent:    coll2,
85		ID:        "b",
86		Path:      "projects/projectID/databases/(default)/documents/X/a/Y/b",
87		shortPath: "X/a/Y/b",
88	}
89	if !testEqual(doc2, wantd2) {
90		t.Fatalf("got %+v, want %+v", doc2, wantd2)
91	}
92}
93
94func TestClientCollDocErrors(t *testing.T) {
95	for _, badColl := range []string{"", "/", "/a/", "/a/b", "a/b/", "a//b"} {
96		coll := testClient.Collection(badColl)
97		if coll != nil {
98			t.Errorf("coll path %q: got %+v, want nil", badColl, coll)
99		}
100	}
101	for _, badDoc := range []string{"", "a", "/", "/a", "a/", "a/b/c", "a//b/c"} {
102		doc := testClient.Doc(badDoc)
103		if doc != nil {
104			t.Errorf("doc path %q: got %+v, want nil", badDoc, doc)
105		}
106	}
107}
108
109func TestGetAll(t *testing.T) {
110	c, srv := newMock(t)
111	defer c.Close()
112	const dbPath = "projects/projectID/databases/(default)"
113	req := &pb.BatchGetDocumentsRequest{
114		Database: dbPath,
115		Documents: []string{
116			dbPath + "/documents/C/a",
117			dbPath + "/documents/C/b",
118			dbPath + "/documents/C/c",
119		},
120	}
121	testGetAll(t, c, srv, dbPath, func(drs []*DocumentRef) ([]*DocumentSnapshot, error) {
122		return c.GetAll(context.Background(), drs)
123	}, req)
124}
125
126func testGetAll(t *testing.T, c *Client, srv *mockServer, dbPath string, getAll func([]*DocumentRef) ([]*DocumentSnapshot, error), req *pb.BatchGetDocumentsRequest) {
127	wantPBDocs := []*pb.Document{
128		{
129			Name:       dbPath + "/documents/C/a",
130			CreateTime: aTimestamp,
131			UpdateTime: aTimestamp,
132			Fields:     map[string]*pb.Value{"f": intval(2)},
133		},
134		nil,
135		{
136			Name:       dbPath + "/documents/C/c",
137			CreateTime: aTimestamp,
138			UpdateTime: aTimestamp,
139			Fields:     map[string]*pb.Value{"f": intval(1)},
140		},
141	}
142	wantReadTimes := []*tspb.Timestamp{aTimestamp, aTimestamp2, aTimestamp3}
143	srv.addRPC(req,
144		[]interface{}{
145			// deliberately put these out of order
146			&pb.BatchGetDocumentsResponse{
147				Result:   &pb.BatchGetDocumentsResponse_Found{wantPBDocs[2]},
148				ReadTime: aTimestamp3,
149			},
150			&pb.BatchGetDocumentsResponse{
151				Result:   &pb.BatchGetDocumentsResponse_Found{wantPBDocs[0]},
152				ReadTime: aTimestamp,
153			},
154			&pb.BatchGetDocumentsResponse{
155				Result:   &pb.BatchGetDocumentsResponse_Missing{dbPath + "/documents/C/b"},
156				ReadTime: aTimestamp2,
157			},
158		},
159	)
160	coll := c.Collection("C")
161	var docRefs []*DocumentRef
162	for _, name := range []string{"a", "b", "c"} {
163		docRefs = append(docRefs, coll.Doc(name))
164	}
165	docs, err := getAll(docRefs)
166	if err != nil {
167		t.Fatal(err)
168	}
169	if got, want := len(docs), len(wantPBDocs); got != want {
170		t.Errorf("got %d docs, wanted %d", got, want)
171	}
172	for i, got := range docs {
173		want, err := newDocumentSnapshot(docRefs[i], wantPBDocs[i], c, wantReadTimes[i])
174		if err != nil {
175			t.Fatal(err)
176		}
177		if diff := testDiff(got, want); diff != "" {
178			t.Errorf("#%d: got=--, want==++\n%s", i, diff)
179		}
180	}
181}
182
183func TestGetAllErrors(t *testing.T) {
184	ctx := context.Background()
185	const (
186		dbPath  = "projects/projectID/databases/(default)"
187		docPath = dbPath + "/documents/C/a"
188	)
189	c, srv := newMock(t)
190	if _, err := c.GetAll(ctx, []*DocumentRef{nil}); err != errNilDocRef {
191		t.Errorf("got %v, want errNilDocRef", err)
192	}
193
194	// Internal server error.
195	srv.addRPC(
196		&pb.BatchGetDocumentsRequest{
197			Database:  dbPath,
198			Documents: []string{docPath},
199		},
200		[]interface{}{status.Errorf(codes.Internal, "")},
201	)
202	_, err := c.GetAll(ctx, []*DocumentRef{c.Doc("C/a")})
203	codeEq(t, "GetAll #1", codes.Internal, err)
204
205	// Doc appears as both found and missing (server bug).
206	srv.reset()
207	srv.addRPC(
208		&pb.BatchGetDocumentsRequest{
209			Database:  dbPath,
210			Documents: []string{docPath},
211		},
212		[]interface{}{
213			&pb.BatchGetDocumentsResponse{
214				Result:   &pb.BatchGetDocumentsResponse_Found{&pb.Document{Name: docPath}},
215				ReadTime: aTimestamp,
216			},
217			&pb.BatchGetDocumentsResponse{
218				Result:   &pb.BatchGetDocumentsResponse_Missing{docPath},
219				ReadTime: aTimestamp,
220			},
221		},
222	)
223	if _, err := c.GetAll(ctx, []*DocumentRef{c.Doc("C/a")}); err == nil {
224		t.Error("got nil, want error")
225	}
226}
227