1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package firestore 16 17import ( 18 "context" 19 "testing" 20 21 "github.com/golang/protobuf/ptypes/empty" 22 "google.golang.org/api/iterator" 23 pb "google.golang.org/genproto/googleapis/firestore/v1" 24 "google.golang.org/grpc/codes" 25 "google.golang.org/grpc/status" 26) 27 28func TestRunTransaction(t *testing.T) { 29 ctx := context.Background() 30 c, srv, cleanup := newMock(t) 31 defer cleanup() 32 33 const db = "projects/projectID/databases/(default)" 34 tid := []byte{1} 35 36 beginReq := &pb.BeginTransactionRequest{Database: db} 37 beginRes := &pb.BeginTransactionResponse{Transaction: tid} 38 commitReq := &pb.CommitRequest{Database: db, Transaction: tid} 39 // Empty transaction. 40 srv.addRPC(beginReq, beginRes) 41 srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp}) 42 err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil }) 43 if err != nil { 44 t.Fatal(err) 45 } 46 47 // Transaction with read and write. 48 srv.reset() 49 srv.addRPC(beginReq, beginRes) 50 aDoc := &pb.Document{ 51 Name: db + "/documents/C/a", 52 CreateTime: aTimestamp, 53 UpdateTime: aTimestamp2, 54 Fields: map[string]*pb.Value{"count": intval(1)}, 55 } 56 srv.addRPC( 57 &pb.BatchGetDocumentsRequest{ 58 Database: c.path(), 59 Documents: []string{db + "/documents/C/a"}, 60 ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid}, 61 }, []interface{}{ 62 &pb.BatchGetDocumentsResponse{ 63 Result: &pb.BatchGetDocumentsResponse_Found{aDoc}, 64 ReadTime: aTimestamp2, 65 }, 66 }) 67 aDoc2 := &pb.Document{ 68 Name: aDoc.Name, 69 Fields: map[string]*pb.Value{"count": intval(2)}, 70 } 71 srv.addRPC( 72 &pb.CommitRequest{ 73 Database: db, 74 Transaction: tid, 75 Writes: []*pb.Write{{ 76 Operation: &pb.Write_Update{aDoc2}, 77 UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}}, 78 CurrentDocument: &pb.Precondition{ 79 ConditionType: &pb.Precondition_Exists{true}, 80 }, 81 }}, 82 }, 83 &pb.CommitResponse{CommitTime: aTimestamp3}, 84 ) 85 err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { 86 docref := c.Collection("C").Doc("a") 87 doc, err := tx.Get(docref) 88 if err != nil { 89 return err 90 } 91 count, err := doc.DataAt("count") 92 if err != nil { 93 return err 94 } 95 return tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}}) 96 }) 97 if err != nil { 98 t.Fatal(err) 99 } 100 101 // Query 102 srv.reset() 103 srv.addRPC(beginReq, beginRes) 104 srv.addRPC( 105 &pb.RunQueryRequest{ 106 Parent: db + "/documents", 107 QueryType: &pb.RunQueryRequest_StructuredQuery{ 108 &pb.StructuredQuery{ 109 From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: "C"}}, 110 }, 111 }, 112 ConsistencySelector: &pb.RunQueryRequest_Transaction{tid}, 113 }, 114 []interface{}{}, 115 ) 116 srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp3}) 117 err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { 118 it := tx.Documents(c.Collection("C")) 119 defer it.Stop() 120 _, err := it.Next() 121 if err != iterator.Done { 122 return err 123 } 124 return nil 125 }) 126 if err != nil { 127 t.Fatal(err) 128 } 129 130 // Retry entire transaction. 131 srv.reset() 132 srv.addRPC(beginReq, beginRes) 133 srv.addRPC(commitReq, status.Errorf(codes.Aborted, "")) 134 srv.addRPC( 135 &pb.BeginTransactionRequest{ 136 Database: db, 137 Options: &pb.TransactionOptions{ 138 Mode: &pb.TransactionOptions_ReadWrite_{ 139 &pb.TransactionOptions_ReadWrite{RetryTransaction: tid}, 140 }, 141 }, 142 }, 143 beginRes, 144 ) 145 srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp}) 146 err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { return nil }) 147 if err != nil { 148 t.Fatal(err) 149 } 150} 151 152func TestTransactionErrors(t *testing.T) { 153 t.Skip("https://github.com/googleapis/google-cloud-go/issues/1708") 154 ctx := context.Background() 155 const db = "projects/projectID/databases/(default)" 156 c, srv, cleanup := newMock(t) 157 defer cleanup() 158 159 var ( 160 tid = []byte{1} 161 unknownErr = status.Errorf(codes.Unknown, "so sad") 162 beginReq = &pb.BeginTransactionRequest{ 163 Database: db, 164 } 165 beginRes = &pb.BeginTransactionResponse{Transaction: tid} 166 getReq = &pb.BatchGetDocumentsRequest{ 167 Database: c.path(), 168 Documents: []string{db + "/documents/C/a"}, 169 ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid}, 170 } 171 rollbackReq = &pb.RollbackRequest{Database: db, Transaction: tid} 172 commitReq = &pb.CommitRequest{Database: db, Transaction: tid} 173 ) 174 175 // BeginTransaction has a permanent error. 176 srv.addRPC(beginReq, unknownErr) 177 err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil }) 178 if status.Code(err) != codes.Unknown { 179 t.Errorf("got <%v>, want Unknown", err) 180 } 181 182 // Get has a permanent error. 183 get := func(_ context.Context, tx *Transaction) error { 184 _, err := tx.Get(c.Doc("C/a")) 185 return err 186 } 187 srv.reset() 188 srv.addRPC(beginReq, beginRes) 189 srv.addRPC(getReq, unknownErr) 190 srv.addRPC(rollbackReq, &empty.Empty{}) 191 err = c.RunTransaction(ctx, get) 192 if status.Code(err) != codes.Unknown { 193 t.Errorf("got <%v>, want Unknown", err) 194 } 195 196 // Get has a permanent error, but the rollback fails. We still 197 // return Get's error. 198 srv.reset() 199 srv.addRPC(beginReq, beginRes) 200 srv.addRPC(getReq, unknownErr) 201 srv.addRPC(rollbackReq, status.Errorf(codes.FailedPrecondition, "")) 202 err = c.RunTransaction(ctx, get) 203 if status.Code(err) != codes.Unknown { 204 t.Errorf("got <%v>, want Unknown", err) 205 } 206 207 // Commit has a permanent error. 208 srv.reset() 209 srv.addRPC(beginReq, beginRes) 210 srv.addRPC(getReq, []interface{}{ 211 &pb.BatchGetDocumentsResponse{ 212 Result: &pb.BatchGetDocumentsResponse_Found{&pb.Document{ 213 Name: "projects/projectID/databases/(default)/documents/C/a", 214 CreateTime: aTimestamp, 215 UpdateTime: aTimestamp2, 216 }}, 217 ReadTime: aTimestamp2, 218 }, 219 }) 220 srv.addRPC(commitReq, unknownErr) 221 err = c.RunTransaction(ctx, get) 222 if status.Code(err) != codes.Unknown { 223 t.Errorf("got <%v>, want Unknown", err) 224 } 225 226 // Read after write. 227 srv.reset() 228 srv.addRPC(beginReq, beginRes) 229 srv.addRPC(rollbackReq, &empty.Empty{}) 230 err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { 231 if err := tx.Delete(c.Doc("C/a")); err != nil { 232 return err 233 } 234 if _, err := tx.Get(c.Doc("C/a")); err != nil { 235 return err 236 } 237 return nil 238 }) 239 if err != errReadAfterWrite { 240 t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite) 241 } 242 243 // Read after write, with query. 244 srv.reset() 245 srv.addRPC(beginReq, beginRes) 246 srv.addRPC(rollbackReq, &empty.Empty{}) 247 err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { 248 if err := tx.Delete(c.Doc("C/a")); err != nil { 249 return err 250 } 251 it := tx.Documents(c.Collection("C").Select("x")) 252 defer it.Stop() 253 if _, err := it.Next(); err != iterator.Done { 254 return err 255 } 256 return nil 257 }) 258 if err != errReadAfterWrite { 259 t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite) 260 } 261 262 // Read after write fails even if the user ignores the read's error. 263 srv.reset() 264 srv.addRPC(beginReq, beginRes) 265 srv.addRPC(rollbackReq, &empty.Empty{}) 266 err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { 267 if err := tx.Delete(c.Doc("C/a")); err != nil { 268 return err 269 } 270 if _, err := tx.Get(c.Doc("C/a")); err != nil { 271 return err 272 } 273 return nil 274 }) 275 if err != errReadAfterWrite { 276 t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite) 277 } 278 279 // Write in read-only transaction. 280 srv.reset() 281 srv.addRPC( 282 &pb.BeginTransactionRequest{ 283 Database: db, 284 Options: &pb.TransactionOptions{ 285 Mode: &pb.TransactionOptions_ReadOnly_{&pb.TransactionOptions_ReadOnly{}}, 286 }, 287 }, 288 beginRes, 289 ) 290 srv.addRPC(rollbackReq, &empty.Empty{}) 291 err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { 292 return tx.Delete(c.Doc("C/a")) 293 }, ReadOnly) 294 if err != errWriteReadOnly { 295 t.Errorf("got <%v>, want <%v>", err, errWriteReadOnly) 296 } 297 298 // Too many retries. 299 srv.reset() 300 srv.addRPC(beginReq, beginRes) 301 srv.addRPC(commitReq, status.Errorf(codes.Aborted, "")) 302 srv.addRPC( 303 &pb.BeginTransactionRequest{ 304 Database: db, 305 Options: &pb.TransactionOptions{ 306 Mode: &pb.TransactionOptions_ReadWrite_{ 307 &pb.TransactionOptions_ReadWrite{RetryTransaction: tid}, 308 }, 309 }, 310 }, 311 beginRes, 312 ) 313 srv.addRPC(commitReq, status.Errorf(codes.Aborted, "")) 314 srv.addRPC(rollbackReq, &empty.Empty{}) 315 err = c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil }, 316 MaxAttempts(2)) 317 if status.Code(err) != codes.Aborted { 318 t.Errorf("got <%v>, want Aborted", err) 319 } 320 321 // Nested transaction. 322 srv.reset() 323 srv.addRPC(beginReq, beginRes) 324 srv.addRPC(rollbackReq, &empty.Empty{}) 325 err = c.RunTransaction(ctx, func(ctx context.Context, tx *Transaction) error { 326 return c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil }) 327 }) 328 if got, want := err, errNestedTransaction; got != want { 329 t.Errorf("got <%v>, want <%v>", got, want) 330 } 331} 332 333func TestTransactionGetAll(t *testing.T) { 334 c, srv, cleanup := newMock(t) 335 defer cleanup() 336 337 const dbPath = "projects/projectID/databases/(default)" 338 tid := []byte{1} 339 beginReq := &pb.BeginTransactionRequest{Database: dbPath} 340 beginRes := &pb.BeginTransactionResponse{Transaction: tid} 341 srv.addRPC(beginReq, beginRes) 342 req := &pb.BatchGetDocumentsRequest{ 343 Database: dbPath, 344 Documents: []string{ 345 dbPath + "/documents/C/a", 346 dbPath + "/documents/C/b", 347 dbPath + "/documents/C/c", 348 }, 349 ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid}, 350 } 351 err := c.RunTransaction(context.Background(), func(_ context.Context, tx *Transaction) error { 352 testGetAll(t, c, srv, dbPath, 353 func(drs []*DocumentRef) ([]*DocumentSnapshot, error) { return tx.GetAll(drs) }, 354 req) 355 commitReq := &pb.CommitRequest{Database: dbPath, Transaction: tid} 356 srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp}) 357 return nil 358 }) 359 if err != nil { 360 t.Fatal(err) 361 } 362} 363 364// Each retry attempt has the same amount of commit writes. 365func TestRunTransaction_Retries(t *testing.T) { 366 ctx := context.Background() 367 c, srv, cleanup := newMock(t) 368 defer cleanup() 369 370 const db = "projects/projectID/databases/(default)" 371 tid := []byte{1} 372 373 srv.addRPC( 374 &pb.BeginTransactionRequest{Database: db}, 375 &pb.BeginTransactionResponse{Transaction: tid}, 376 ) 377 378 aDoc := &pb.Document{ 379 Name: db + "/documents/C/a", 380 CreateTime: aTimestamp, 381 UpdateTime: aTimestamp2, 382 Fields: map[string]*pb.Value{"count": intval(1)}, 383 } 384 aDoc2 := &pb.Document{ 385 Name: aDoc.Name, 386 Fields: map[string]*pb.Value{"count": intval(7)}, 387 } 388 389 srv.addRPC( 390 &pb.CommitRequest{ 391 Database: db, 392 Transaction: tid, 393 Writes: []*pb.Write{{ 394 Operation: &pb.Write_Update{aDoc2}, 395 UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}}, 396 CurrentDocument: &pb.Precondition{ 397 ConditionType: &pb.Precondition_Exists{true}, 398 }, 399 }}, 400 }, 401 status.Errorf(codes.Aborted, "something failed! please retry me!"), 402 ) 403 404 srv.addRPC( 405 &pb.BeginTransactionRequest{ 406 Database: db, 407 Options: &pb.TransactionOptions{ 408 Mode: &pb.TransactionOptions_ReadWrite_{ 409 &pb.TransactionOptions_ReadWrite{RetryTransaction: tid}, 410 }, 411 }, 412 }, 413 &pb.BeginTransactionResponse{Transaction: tid}, 414 ) 415 416 srv.addRPC( 417 &pb.CommitRequest{ 418 Database: db, 419 Transaction: tid, 420 Writes: []*pb.Write{{ 421 Operation: &pb.Write_Update{aDoc2}, 422 UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}}, 423 CurrentDocument: &pb.Precondition{ 424 ConditionType: &pb.Precondition_Exists{true}, 425 }, 426 }}, 427 }, 428 &pb.CommitResponse{CommitTime: aTimestamp3}, 429 ) 430 431 err := c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { 432 docref := c.Collection("C").Doc("a") 433 return tx.Update(docref, []Update{{Path: "count", Value: 7}}) 434 }) 435 if err != nil { 436 t.Fatal(err) 437 } 438} 439 440// Non-transactional operations are allowed in transactions (although 441// discouraged). 442func TestRunTransaction_NonTransactionalOp(t *testing.T) { 443 ctx := context.Background() 444 c, srv, cleanup := newMock(t) 445 defer cleanup() 446 447 const db = "projects/projectID/databases/(default)" 448 tid := []byte{1} 449 450 beginReq := &pb.BeginTransactionRequest{Database: db} 451 beginRes := &pb.BeginTransactionResponse{Transaction: tid} 452 453 srv.reset() 454 srv.addRPC(beginReq, beginRes) 455 aDoc := &pb.Document{ 456 Name: db + "/documents/C/a", 457 CreateTime: aTimestamp, 458 UpdateTime: aTimestamp2, 459 Fields: map[string]*pb.Value{"count": intval(1)}, 460 } 461 srv.addRPC( 462 &pb.BatchGetDocumentsRequest{ 463 Database: c.path(), 464 Documents: []string{db + "/documents/C/a"}, 465 }, []interface{}{ 466 &pb.BatchGetDocumentsResponse{ 467 Result: &pb.BatchGetDocumentsResponse_Found{aDoc}, 468 ReadTime: aTimestamp2, 469 }, 470 }) 471 srv.addRPC( 472 &pb.CommitRequest{ 473 Database: db, 474 Transaction: tid, 475 }, 476 &pb.CommitResponse{CommitTime: aTimestamp3}, 477 ) 478 479 if err := c.RunTransaction(ctx, func(ctx2 context.Context, tx *Transaction) error { 480 docref := c.Collection("C").Doc("a") 481 if _, err := c.GetAll(ctx2, []*DocumentRef{docref}); err != nil { 482 t.Fatal(err) 483 } 484 return nil 485 }); err != nil { 486 t.Fatal(err) 487 } 488} 489