1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package firestore
16
17import (
18	"context"
19	"testing"
20
21	"github.com/golang/protobuf/ptypes/empty"
22	"google.golang.org/api/iterator"
23	pb "google.golang.org/genproto/googleapis/firestore/v1"
24	"google.golang.org/grpc/codes"
25	"google.golang.org/grpc/status"
26)
27
28func TestRunTransaction(t *testing.T) {
29	ctx := context.Background()
30	c, srv, cleanup := newMock(t)
31	defer cleanup()
32
33	const db = "projects/projectID/databases/(default)"
34	tid := []byte{1}
35
36	beginReq := &pb.BeginTransactionRequest{Database: db}
37	beginRes := &pb.BeginTransactionResponse{Transaction: tid}
38	commitReq := &pb.CommitRequest{Database: db, Transaction: tid}
39	// Empty transaction.
40	srv.addRPC(beginReq, beginRes)
41	srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
42	err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
43	if err != nil {
44		t.Fatal(err)
45	}
46
47	// Transaction with read and write.
48	srv.reset()
49	srv.addRPC(beginReq, beginRes)
50	aDoc := &pb.Document{
51		Name:       db + "/documents/C/a",
52		CreateTime: aTimestamp,
53		UpdateTime: aTimestamp2,
54		Fields:     map[string]*pb.Value{"count": intval(1)},
55	}
56	srv.addRPC(
57		&pb.BatchGetDocumentsRequest{
58			Database:            c.path(),
59			Documents:           []string{db + "/documents/C/a"},
60			ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
61		}, []interface{}{
62			&pb.BatchGetDocumentsResponse{
63				Result:   &pb.BatchGetDocumentsResponse_Found{aDoc},
64				ReadTime: aTimestamp2,
65			},
66		})
67	aDoc2 := &pb.Document{
68		Name:   aDoc.Name,
69		Fields: map[string]*pb.Value{"count": intval(2)},
70	}
71	srv.addRPC(
72		&pb.CommitRequest{
73			Database:    db,
74			Transaction: tid,
75			Writes: []*pb.Write{{
76				Operation:  &pb.Write_Update{aDoc2},
77				UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
78				CurrentDocument: &pb.Precondition{
79					ConditionType: &pb.Precondition_Exists{true},
80				},
81			}},
82		},
83		&pb.CommitResponse{CommitTime: aTimestamp3},
84	)
85	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
86		docref := c.Collection("C").Doc("a")
87		doc, err := tx.Get(docref)
88		if err != nil {
89			return err
90		}
91		count, err := doc.DataAt("count")
92		if err != nil {
93			return err
94		}
95		return tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}})
96	})
97	if err != nil {
98		t.Fatal(err)
99	}
100
101	// Query
102	srv.reset()
103	srv.addRPC(beginReq, beginRes)
104	srv.addRPC(
105		&pb.RunQueryRequest{
106			Parent: db + "/documents",
107			QueryType: &pb.RunQueryRequest_StructuredQuery{
108				&pb.StructuredQuery{
109					From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: "C"}},
110				},
111			},
112			ConsistencySelector: &pb.RunQueryRequest_Transaction{tid},
113		},
114		[]interface{}{},
115	)
116	srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp3})
117	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
118		it := tx.Documents(c.Collection("C"))
119		defer it.Stop()
120		_, err := it.Next()
121		if err != iterator.Done {
122			return err
123		}
124		return nil
125	})
126	if err != nil {
127		t.Fatal(err)
128	}
129
130	// Retry entire transaction.
131	srv.reset()
132	srv.addRPC(beginReq, beginRes)
133	srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
134	srv.addRPC(
135		&pb.BeginTransactionRequest{
136			Database: db,
137			Options: &pb.TransactionOptions{
138				Mode: &pb.TransactionOptions_ReadWrite_{
139					&pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
140				},
141			},
142		},
143		beginRes,
144	)
145	srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
146	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { return nil })
147	if err != nil {
148		t.Fatal(err)
149	}
150}
151
152func TestTransactionErrors(t *testing.T) {
153	t.Skip("https://github.com/googleapis/google-cloud-go/issues/1708")
154	ctx := context.Background()
155	const db = "projects/projectID/databases/(default)"
156	c, srv, cleanup := newMock(t)
157	defer cleanup()
158
159	var (
160		tid        = []byte{1}
161		unknownErr = status.Errorf(codes.Unknown, "so sad")
162		beginReq   = &pb.BeginTransactionRequest{
163			Database: db,
164		}
165		beginRes = &pb.BeginTransactionResponse{Transaction: tid}
166		getReq   = &pb.BatchGetDocumentsRequest{
167			Database:            c.path(),
168			Documents:           []string{db + "/documents/C/a"},
169			ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
170		}
171		rollbackReq = &pb.RollbackRequest{Database: db, Transaction: tid}
172		commitReq   = &pb.CommitRequest{Database: db, Transaction: tid}
173	)
174
175	// BeginTransaction has a permanent error.
176	srv.addRPC(beginReq, unknownErr)
177	err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
178	if status.Code(err) != codes.Unknown {
179		t.Errorf("got <%v>, want Unknown", err)
180	}
181
182	// Get has a permanent error.
183	get := func(_ context.Context, tx *Transaction) error {
184		_, err := tx.Get(c.Doc("C/a"))
185		return err
186	}
187	srv.reset()
188	srv.addRPC(beginReq, beginRes)
189	srv.addRPC(getReq, unknownErr)
190	srv.addRPC(rollbackReq, &empty.Empty{})
191	err = c.RunTransaction(ctx, get)
192	if status.Code(err) != codes.Unknown {
193		t.Errorf("got <%v>, want Unknown", err)
194	}
195
196	// Get has a permanent error, but the rollback fails. We still
197	// return Get's error.
198	srv.reset()
199	srv.addRPC(beginReq, beginRes)
200	srv.addRPC(getReq, unknownErr)
201	srv.addRPC(rollbackReq, status.Errorf(codes.FailedPrecondition, ""))
202	err = c.RunTransaction(ctx, get)
203	if status.Code(err) != codes.Unknown {
204		t.Errorf("got <%v>, want Unknown", err)
205	}
206
207	// Commit has a permanent error.
208	srv.reset()
209	srv.addRPC(beginReq, beginRes)
210	srv.addRPC(getReq, []interface{}{
211		&pb.BatchGetDocumentsResponse{
212			Result: &pb.BatchGetDocumentsResponse_Found{&pb.Document{
213				Name:       "projects/projectID/databases/(default)/documents/C/a",
214				CreateTime: aTimestamp,
215				UpdateTime: aTimestamp2,
216			}},
217			ReadTime: aTimestamp2,
218		},
219	})
220	srv.addRPC(commitReq, unknownErr)
221	err = c.RunTransaction(ctx, get)
222	if status.Code(err) != codes.Unknown {
223		t.Errorf("got <%v>, want Unknown", err)
224	}
225
226	// Read after write.
227	srv.reset()
228	srv.addRPC(beginReq, beginRes)
229	srv.addRPC(rollbackReq, &empty.Empty{})
230	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
231		if err := tx.Delete(c.Doc("C/a")); err != nil {
232			return err
233		}
234		if _, err := tx.Get(c.Doc("C/a")); err != nil {
235			return err
236		}
237		return nil
238	})
239	if err != errReadAfterWrite {
240		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
241	}
242
243	// Read after write, with query.
244	srv.reset()
245	srv.addRPC(beginReq, beginRes)
246	srv.addRPC(rollbackReq, &empty.Empty{})
247	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
248		if err := tx.Delete(c.Doc("C/a")); err != nil {
249			return err
250		}
251		it := tx.Documents(c.Collection("C").Select("x"))
252		defer it.Stop()
253		if _, err := it.Next(); err != iterator.Done {
254			return err
255		}
256		return nil
257	})
258	if err != errReadAfterWrite {
259		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
260	}
261
262	// Read after write fails even if the user ignores the read's error.
263	srv.reset()
264	srv.addRPC(beginReq, beginRes)
265	srv.addRPC(rollbackReq, &empty.Empty{})
266	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
267		if err := tx.Delete(c.Doc("C/a")); err != nil {
268			return err
269		}
270		if _, err := tx.Get(c.Doc("C/a")); err != nil {
271			return err
272		}
273		return nil
274	})
275	if err != errReadAfterWrite {
276		t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
277	}
278
279	// Write in read-only transaction.
280	srv.reset()
281	srv.addRPC(
282		&pb.BeginTransactionRequest{
283			Database: db,
284			Options: &pb.TransactionOptions{
285				Mode: &pb.TransactionOptions_ReadOnly_{&pb.TransactionOptions_ReadOnly{}},
286			},
287		},
288		beginRes,
289	)
290	srv.addRPC(rollbackReq, &empty.Empty{})
291	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
292		return tx.Delete(c.Doc("C/a"))
293	}, ReadOnly)
294	if err != errWriteReadOnly {
295		t.Errorf("got <%v>, want <%v>", err, errWriteReadOnly)
296	}
297
298	// Too many retries.
299	srv.reset()
300	srv.addRPC(beginReq, beginRes)
301	srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
302	srv.addRPC(
303		&pb.BeginTransactionRequest{
304			Database: db,
305			Options: &pb.TransactionOptions{
306				Mode: &pb.TransactionOptions_ReadWrite_{
307					&pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
308				},
309			},
310		},
311		beginRes,
312	)
313	srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
314	srv.addRPC(rollbackReq, &empty.Empty{})
315	err = c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil },
316		MaxAttempts(2))
317	if status.Code(err) != codes.Aborted {
318		t.Errorf("got <%v>, want Aborted", err)
319	}
320
321	// Nested transaction.
322	srv.reset()
323	srv.addRPC(beginReq, beginRes)
324	srv.addRPC(rollbackReq, &empty.Empty{})
325	err = c.RunTransaction(ctx, func(ctx context.Context, tx *Transaction) error {
326		return c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
327	})
328	if got, want := err, errNestedTransaction; got != want {
329		t.Errorf("got <%v>, want <%v>", got, want)
330	}
331}
332
333func TestTransactionGetAll(t *testing.T) {
334	c, srv, cleanup := newMock(t)
335	defer cleanup()
336
337	const dbPath = "projects/projectID/databases/(default)"
338	tid := []byte{1}
339	beginReq := &pb.BeginTransactionRequest{Database: dbPath}
340	beginRes := &pb.BeginTransactionResponse{Transaction: tid}
341	srv.addRPC(beginReq, beginRes)
342	req := &pb.BatchGetDocumentsRequest{
343		Database: dbPath,
344		Documents: []string{
345			dbPath + "/documents/C/a",
346			dbPath + "/documents/C/b",
347			dbPath + "/documents/C/c",
348		},
349		ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
350	}
351	err := c.RunTransaction(context.Background(), func(_ context.Context, tx *Transaction) error {
352		testGetAll(t, c, srv, dbPath,
353			func(drs []*DocumentRef) ([]*DocumentSnapshot, error) { return tx.GetAll(drs) },
354			req)
355		commitReq := &pb.CommitRequest{Database: dbPath, Transaction: tid}
356		srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
357		return nil
358	})
359	if err != nil {
360		t.Fatal(err)
361	}
362}
363
364// Each retry attempt has the same amount of commit writes.
365func TestRunTransaction_Retries(t *testing.T) {
366	ctx := context.Background()
367	c, srv, cleanup := newMock(t)
368	defer cleanup()
369
370	const db = "projects/projectID/databases/(default)"
371	tid := []byte{1}
372
373	srv.addRPC(
374		&pb.BeginTransactionRequest{Database: db},
375		&pb.BeginTransactionResponse{Transaction: tid},
376	)
377
378	aDoc := &pb.Document{
379		Name:       db + "/documents/C/a",
380		CreateTime: aTimestamp,
381		UpdateTime: aTimestamp2,
382		Fields:     map[string]*pb.Value{"count": intval(1)},
383	}
384	aDoc2 := &pb.Document{
385		Name:   aDoc.Name,
386		Fields: map[string]*pb.Value{"count": intval(7)},
387	}
388
389	srv.addRPC(
390		&pb.CommitRequest{
391			Database:    db,
392			Transaction: tid,
393			Writes: []*pb.Write{{
394				Operation:  &pb.Write_Update{aDoc2},
395				UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
396				CurrentDocument: &pb.Precondition{
397					ConditionType: &pb.Precondition_Exists{true},
398				},
399			}},
400		},
401		status.Errorf(codes.Aborted, "something failed! please retry me!"),
402	)
403
404	srv.addRPC(
405		&pb.BeginTransactionRequest{
406			Database: db,
407			Options: &pb.TransactionOptions{
408				Mode: &pb.TransactionOptions_ReadWrite_{
409					&pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
410				},
411			},
412		},
413		&pb.BeginTransactionResponse{Transaction: tid},
414	)
415
416	srv.addRPC(
417		&pb.CommitRequest{
418			Database:    db,
419			Transaction: tid,
420			Writes: []*pb.Write{{
421				Operation:  &pb.Write_Update{aDoc2},
422				UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
423				CurrentDocument: &pb.Precondition{
424					ConditionType: &pb.Precondition_Exists{true},
425				},
426			}},
427		},
428		&pb.CommitResponse{CommitTime: aTimestamp3},
429	)
430
431	err := c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
432		docref := c.Collection("C").Doc("a")
433		return tx.Update(docref, []Update{{Path: "count", Value: 7}})
434	})
435	if err != nil {
436		t.Fatal(err)
437	}
438}
439
440// Non-transactional operations are allowed in transactions (although
441// discouraged).
442func TestRunTransaction_NonTransactionalOp(t *testing.T) {
443	ctx := context.Background()
444	c, srv, cleanup := newMock(t)
445	defer cleanup()
446
447	const db = "projects/projectID/databases/(default)"
448	tid := []byte{1}
449
450	beginReq := &pb.BeginTransactionRequest{Database: db}
451	beginRes := &pb.BeginTransactionResponse{Transaction: tid}
452
453	srv.reset()
454	srv.addRPC(beginReq, beginRes)
455	aDoc := &pb.Document{
456		Name:       db + "/documents/C/a",
457		CreateTime: aTimestamp,
458		UpdateTime: aTimestamp2,
459		Fields:     map[string]*pb.Value{"count": intval(1)},
460	}
461	srv.addRPC(
462		&pb.BatchGetDocumentsRequest{
463			Database:  c.path(),
464			Documents: []string{db + "/documents/C/a"},
465		}, []interface{}{
466			&pb.BatchGetDocumentsResponse{
467				Result:   &pb.BatchGetDocumentsResponse_Found{aDoc},
468				ReadTime: aTimestamp2,
469			},
470		})
471	srv.addRPC(
472		&pb.CommitRequest{
473			Database:    db,
474			Transaction: tid,
475		},
476		&pb.CommitResponse{CommitTime: aTimestamp3},
477	)
478
479	if err := c.RunTransaction(ctx, func(ctx2 context.Context, tx *Transaction) error {
480		docref := c.Collection("C").Doc("a")
481		if _, err := c.GetAll(ctx2, []*DocumentRef{docref}); err != nil {
482			t.Fatal(err)
483		}
484		return nil
485	}); err != nil {
486		t.Fatal(err)
487	}
488}
489