1// Copyright 2019 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package testutil_test 16 17import ( 18 "strconv" 19 20 . "cloud.google.com/go/spanner/internal/testutil" 21 22 "context" 23 "flag" 24 "fmt" 25 "log" 26 "net" 27 "os" 28 "strings" 29 "testing" 30 31 structpb "github.com/golang/protobuf/ptypes/struct" 32 spannerpb "google.golang.org/genproto/googleapis/spanner/v1" 33 "google.golang.org/grpc/codes" 34 35 apiv1 "cloud.google.com/go/spanner/apiv1" 36 "google.golang.org/api/iterator" 37 "google.golang.org/api/option" 38 "google.golang.org/grpc" 39 40 gstatus "google.golang.org/grpc/status" 41) 42 43// clientOpt is the option tests should use to connect to the test server. 44// It is initialized by TestMain. 45var serverAddress string 46var clientOpt option.ClientOption 47var testSpanner InMemSpannerServer 48 49// Mocked selectSQL statement. 50const selectSQL = "SELECT FOO FROM BAR" 51const selectRowCount int64 = 2 52const selectColCount int = 1 53 54var selectValues = [...]int64{1, 2} 55 56// Mocked DML statement. 57const updateSQL = "UPDATE FOO SET BAR=1 WHERE ID=ID" 58const updateRowCount int64 = 2 59 60func TestMain(m *testing.M) { 61 flag.Parse() 62 63 testSpanner = NewInMemSpannerServer() 64 serv := grpc.NewServer() 65 spannerpb.RegisterSpannerServer(serv, testSpanner) 66 67 lis, err := net.Listen("tcp", "localhost:0") 68 if err != nil { 69 log.Fatal(err) 70 } 71 go serv.Serve(lis) 72 73 serverAddress = lis.Addr().String() 74 conn, err := grpc.Dial(serverAddress, grpc.WithInsecure()) 75 if err != nil { 76 log.Fatal(err) 77 } 78 clientOpt = option.WithGRPCConn(conn) 79 80 os.Exit(m.Run()) 81} 82 83// Resets the mock server to its default values and registers a mocked result 84// for the statements "SELECT FOO FROM BAR" and 85// "UPDATE FOO SET BAR=1 WHERE ID=ID". 86func setup() { 87 testSpanner.Reset() 88 fields := make([]*spannerpb.StructType_Field, selectColCount) 89 fields[0] = &spannerpb.StructType_Field{ 90 Name: "FOO", 91 Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, 92 } 93 rowType := &spannerpb.StructType{ 94 Fields: fields, 95 } 96 metadata := &spannerpb.ResultSetMetadata{ 97 RowType: rowType, 98 } 99 rows := make([]*structpb.ListValue, selectRowCount) 100 for idx, value := range selectValues { 101 rowValue := make([]*structpb.Value, selectColCount) 102 rowValue[0] = &structpb.Value{ 103 Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(value, 10)}, 104 } 105 rows[idx] = &structpb.ListValue{ 106 Values: rowValue, 107 } 108 } 109 resultSet := &spannerpb.ResultSet{ 110 Metadata: metadata, 111 Rows: rows, 112 } 113 result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} 114 testSpanner.PutStatementResult(selectSQL, result) 115 116 updateResult := &StatementResult{Type: StatementResultUpdateCount, UpdateCount: updateRowCount} 117 testSpanner.PutStatementResult(updateSQL, updateResult) 118} 119 120func TestSpannerCreateSession(t *testing.T) { 121 testSpanner.Reset() 122 var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 123 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 124 var request = &spannerpb.CreateSessionRequest{ 125 Database: formattedDatabase, 126 } 127 128 c, err := apiv1.NewClient(context.Background(), clientOpt) 129 if err != nil { 130 t.Fatal(err) 131 } 132 resp, err := c.CreateSession(context.Background(), request) 133 if err != nil { 134 t.Fatal(err) 135 } 136 if strings.Index(resp.Name, expectedName) != 0 { 137 t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", resp.Name, expectedName) 138 } 139} 140 141func TestSpannerCreateSession_Unavailable(t *testing.T) { 142 testSpanner.Reset() 143 var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 144 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 145 var request = &spannerpb.CreateSessionRequest{ 146 Database: formattedDatabase, 147 } 148 149 c, err := apiv1.NewClient(context.Background(), clientOpt) 150 if err != nil { 151 t.Fatal(err) 152 } 153 testSpanner.SetError(gstatus.Error(codes.Unavailable, "Temporary unavailable")) 154 resp, err := c.CreateSession(context.Background(), request) 155 if err != nil { 156 t.Fatal(err) 157 } 158 if strings.Index(resp.Name, expectedName) != 0 { 159 t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", resp.Name, expectedName) 160 } 161} 162 163func TestSpannerGetSession(t *testing.T) { 164 testSpanner.Reset() 165 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 166 var createRequest = &spannerpb.CreateSessionRequest{ 167 Database: formattedDatabase, 168 } 169 170 c, err := apiv1.NewClient(context.Background(), clientOpt) 171 if err != nil { 172 t.Fatal(err) 173 } 174 createResp, err := c.CreateSession(context.Background(), createRequest) 175 if err != nil { 176 t.Fatal(err) 177 } 178 var getRequest = &spannerpb.GetSessionRequest{ 179 Name: createResp.Name, 180 } 181 getResp, err := c.GetSession(context.Background(), getRequest) 182 if err != nil { 183 t.Fatal(err) 184 } 185 if getResp.Name != getRequest.Name { 186 t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", getResp.Name, getRequest.Name) 187 } 188} 189 190func TestSpannerListSessions(t *testing.T) { 191 testSpanner.Reset() 192 const expectedNumberOfSessions = 5 193 var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 194 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 195 var createRequest = &spannerpb.CreateSessionRequest{ 196 Database: formattedDatabase, 197 } 198 199 c, err := apiv1.NewClient(context.Background(), clientOpt) 200 if err != nil { 201 t.Fatal(err) 202 } 203 for i := 0; i < expectedNumberOfSessions; i++ { 204 _, err := c.CreateSession(context.Background(), createRequest) 205 if err != nil { 206 t.Fatal(err) 207 } 208 } 209 var listRequest = &spannerpb.ListSessionsRequest{ 210 Database: formattedDatabase, 211 } 212 var sessionCount int 213 listResp := c.ListSessions(context.Background(), listRequest) 214 for { 215 session, err := listResp.Next() 216 if err == iterator.Done { 217 break 218 } 219 if err != nil { 220 t.Fatal(err) 221 } 222 if strings.Index(session.Name, expectedName) != 0 { 223 t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", session.Name, expectedName) 224 } 225 sessionCount++ 226 } 227 if sessionCount != expectedNumberOfSessions { 228 t.Errorf("Session count mismatch\nGot: %d\nWant: %d", sessionCount, expectedNumberOfSessions) 229 } 230} 231 232func TestSpannerDeleteSession(t *testing.T) { 233 testSpanner.Reset() 234 const expectedNumberOfSessions = 5 235 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 236 var createRequest = &spannerpb.CreateSessionRequest{ 237 Database: formattedDatabase, 238 } 239 240 c, err := apiv1.NewClient(context.Background(), clientOpt) 241 if err != nil { 242 t.Fatal(err) 243 } 244 for i := 0; i < expectedNumberOfSessions; i++ { 245 _, err := c.CreateSession(context.Background(), createRequest) 246 if err != nil { 247 t.Fatal(err) 248 } 249 } 250 var listRequest = &spannerpb.ListSessionsRequest{ 251 Database: formattedDatabase, 252 } 253 var sessionCount int 254 listResp := c.ListSessions(context.Background(), listRequest) 255 for { 256 session, err := listResp.Next() 257 if err == iterator.Done { 258 break 259 } 260 if err != nil { 261 t.Fatal(err) 262 } 263 var deleteRequest = &spannerpb.DeleteSessionRequest{ 264 Name: session.Name, 265 } 266 c.DeleteSession(context.Background(), deleteRequest) 267 sessionCount++ 268 } 269 if sessionCount != expectedNumberOfSessions { 270 t.Errorf("Session count mismatch\nGot: %d\nWant: %d", sessionCount, expectedNumberOfSessions) 271 } 272 // Re-list all sessions. This should now be empty. 273 listResp = c.ListSessions(context.Background(), listRequest) 274 _, err = listResp.Next() 275 if err != iterator.Done { 276 t.Errorf("expected empty session iterator") 277 } 278} 279 280func TestSpannerExecuteSql(t *testing.T) { 281 setup() 282 c, err := apiv1.NewClient(context.Background(), clientOpt) 283 if err != nil { 284 t.Fatal(err) 285 } 286 287 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 288 var createRequest = &spannerpb.CreateSessionRequest{ 289 Database: formattedDatabase, 290 } 291 session, err := c.CreateSession(context.Background(), createRequest) 292 if err != nil { 293 t.Fatal(err) 294 } 295 request := &spannerpb.ExecuteSqlRequest{ 296 Session: session.Name, 297 Sql: selectSQL, 298 Transaction: &spannerpb.TransactionSelector{ 299 Selector: &spannerpb.TransactionSelector_SingleUse{ 300 SingleUse: &spannerpb.TransactionOptions{ 301 Mode: &spannerpb.TransactionOptions_ReadOnly_{ 302 ReadOnly: &spannerpb.TransactionOptions_ReadOnly{ 303 ReturnReadTimestamp: false, 304 TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{ 305 Strong: true, 306 }, 307 }, 308 }, 309 }, 310 }, 311 }, 312 Seqno: 1, 313 QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, 314 } 315 response, err := c.ExecuteSql(context.Background(), request) 316 if err != nil { 317 t.Fatal(err) 318 } 319 var rowCount int64 320 for _, row := range response.Rows { 321 if len(row.Values) != selectColCount { 322 t.Fatalf("Column count mismatch\nGot: %d\nWant: %d", len(row.Values), selectColCount) 323 } 324 rowCount++ 325 } 326 if rowCount != selectRowCount { 327 t.Fatalf("Row count mismatch\nGot: %d\nWant: %d", rowCount, selectRowCount) 328 } 329} 330 331func TestSpannerExecuteSqlDml(t *testing.T) { 332 setup() 333 c, err := apiv1.NewClient(context.Background(), clientOpt) 334 if err != nil { 335 t.Fatal(err) 336 } 337 338 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 339 var createRequest = &spannerpb.CreateSessionRequest{ 340 Database: formattedDatabase, 341 } 342 session, err := c.CreateSession(context.Background(), createRequest) 343 if err != nil { 344 t.Fatal(err) 345 } 346 request := &spannerpb.ExecuteSqlRequest{ 347 Session: session.Name, 348 Sql: updateSQL, 349 Transaction: &spannerpb.TransactionSelector{ 350 Selector: &spannerpb.TransactionSelector_Begin{ 351 Begin: &spannerpb.TransactionOptions{ 352 Mode: &spannerpb.TransactionOptions_ReadWrite_{ 353 ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, 354 }, 355 }, 356 }, 357 }, 358 Seqno: 1, 359 QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, 360 } 361 response, err := c.ExecuteSql(context.Background(), request) 362 if err != nil { 363 t.Fatal(err) 364 } 365 var rowCount int64 = response.Stats.GetRowCountExact() 366 if rowCount != updateRowCount { 367 t.Fatalf("Update count mismatch\nGot: %d\nWant: %d", rowCount, updateRowCount) 368 } 369} 370 371func TestSpannerExecuteStreamingSql(t *testing.T) { 372 setup() 373 c, err := apiv1.NewClient(context.Background(), clientOpt) 374 if err != nil { 375 t.Fatal(err) 376 } 377 378 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 379 var createRequest = &spannerpb.CreateSessionRequest{ 380 Database: formattedDatabase, 381 } 382 session, err := c.CreateSession(context.Background(), createRequest) 383 if err != nil { 384 t.Fatal(err) 385 } 386 request := &spannerpb.ExecuteSqlRequest{ 387 Session: session.Name, 388 Sql: selectSQL, 389 Transaction: &spannerpb.TransactionSelector{ 390 Selector: &spannerpb.TransactionSelector_SingleUse{ 391 SingleUse: &spannerpb.TransactionOptions{ 392 Mode: &spannerpb.TransactionOptions_ReadOnly_{ 393 ReadOnly: &spannerpb.TransactionOptions_ReadOnly{ 394 ReturnReadTimestamp: false, 395 TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{ 396 Strong: true, 397 }, 398 }, 399 }, 400 }, 401 }, 402 }, 403 Seqno: 1, 404 QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, 405 } 406 response, err := c.ExecuteStreamingSql(context.Background(), request) 407 if err != nil { 408 t.Fatal(err) 409 } 410 var rowIndex int64 411 var colCount int 412 for { 413 for rowIndexInPartial := int64(0); rowIndexInPartial < MaxRowsPerPartialResultSet; rowIndexInPartial++ { 414 partial, err := response.Recv() 415 if err != nil { 416 t.Fatal(err) 417 } 418 if rowIndex == 0 { 419 colCount = len(partial.Metadata.RowType.Fields) 420 if colCount != selectColCount { 421 t.Fatalf("Column count mismatch\nGot: %d\nWant: %d", colCount, selectColCount) 422 } 423 } 424 for col := 0; col < colCount; col++ { 425 pIndex := rowIndexInPartial*int64(colCount) + int64(col) 426 val, err := strconv.ParseInt(partial.Values[pIndex].GetStringValue(), 10, 64) 427 if err != nil { 428 t.Fatalf("Error parsing integer at #%d: %v", pIndex, err) 429 } 430 if val != selectValues[rowIndex] { 431 t.Fatalf("Value mismatch at index %d\nGot: %d\nWant: %d", rowIndex, val, selectValues[rowIndex]) 432 } 433 } 434 rowIndex++ 435 } 436 if rowIndex == selectRowCount { 437 break 438 } 439 } 440 if rowIndex != selectRowCount { 441 t.Fatalf("Row count mismatch\nGot: %d\nWant: %d", rowIndex, selectRowCount) 442 } 443} 444 445func TestSpannerExecuteBatchDml(t *testing.T) { 446 setup() 447 c, err := apiv1.NewClient(context.Background(), clientOpt) 448 if err != nil { 449 t.Fatal(err) 450 } 451 452 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 453 var createRequest = &spannerpb.CreateSessionRequest{ 454 Database: formattedDatabase, 455 } 456 session, err := c.CreateSession(context.Background(), createRequest) 457 if err != nil { 458 t.Fatal(err) 459 } 460 statements := make([]*spannerpb.ExecuteBatchDmlRequest_Statement, 3) 461 for idx := 0; idx < len(statements); idx++ { 462 statements[idx] = &spannerpb.ExecuteBatchDmlRequest_Statement{Sql: updateSQL} 463 } 464 executeBatchDmlRequest := &spannerpb.ExecuteBatchDmlRequest{ 465 Session: session.Name, 466 Statements: statements, 467 Transaction: &spannerpb.TransactionSelector{ 468 Selector: &spannerpb.TransactionSelector_Begin{ 469 Begin: &spannerpb.TransactionOptions{ 470 Mode: &spannerpb.TransactionOptions_ReadWrite_{ 471 ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, 472 }, 473 }, 474 }, 475 }, 476 Seqno: 1, 477 } 478 response, err := c.ExecuteBatchDml(context.Background(), executeBatchDmlRequest) 479 if err != nil { 480 t.Fatal(err) 481 } 482 var totalRowCount int64 483 for _, res := range response.ResultSets { 484 var rowCount int64 = res.Stats.GetRowCountExact() 485 if rowCount != updateRowCount { 486 t.Fatalf("Update count mismatch\nGot: %d\nWant: %d", rowCount, updateRowCount) 487 } 488 totalRowCount += rowCount 489 } 490 if totalRowCount != updateRowCount*int64(len(statements)) { 491 t.Fatalf("Total update count mismatch\nGot: %d\nWant: %d", totalRowCount, updateRowCount*int64(len(statements))) 492 } 493} 494 495func TestBeginTransaction(t *testing.T) { 496 setup() 497 c, err := apiv1.NewClient(context.Background(), clientOpt) 498 if err != nil { 499 t.Fatal(err) 500 } 501 502 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 503 var createRequest = &spannerpb.CreateSessionRequest{ 504 Database: formattedDatabase, 505 } 506 session, err := c.CreateSession(context.Background(), createRequest) 507 if err != nil { 508 t.Fatal(err) 509 } 510 beginRequest := &spannerpb.BeginTransactionRequest{ 511 Session: session.Name, 512 Options: &spannerpb.TransactionOptions{ 513 Mode: &spannerpb.TransactionOptions_ReadWrite_{ 514 ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, 515 }, 516 }, 517 } 518 tx, err := c.BeginTransaction(context.Background(), beginRequest) 519 if err != nil { 520 t.Fatal(err) 521 } 522 expectedName := fmt.Sprintf("%s/transactions/", session.Name) 523 if strings.Index(string(tx.Id), expectedName) != 0 { 524 t.Errorf("Transaction name mismatch\nGot: %s\nWant: Name should start with %s)", string(tx.Id), expectedName) 525 } 526} 527 528func TestCommitTransaction(t *testing.T) { 529 setup() 530 c, err := apiv1.NewClient(context.Background(), clientOpt) 531 if err != nil { 532 t.Fatal(err) 533 } 534 535 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 536 var createRequest = &spannerpb.CreateSessionRequest{ 537 Database: formattedDatabase, 538 } 539 session, err := c.CreateSession(context.Background(), createRequest) 540 if err != nil { 541 t.Fatal(err) 542 } 543 beginRequest := &spannerpb.BeginTransactionRequest{ 544 Session: session.Name, 545 Options: &spannerpb.TransactionOptions{ 546 Mode: &spannerpb.TransactionOptions_ReadWrite_{ 547 ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, 548 }, 549 }, 550 } 551 tx, err := c.BeginTransaction(context.Background(), beginRequest) 552 if err != nil { 553 t.Fatal(err) 554 } 555 commitRequest := &spannerpb.CommitRequest{ 556 Session: session.Name, 557 Transaction: &spannerpb.CommitRequest_TransactionId{ 558 TransactionId: tx.Id, 559 }, 560 } 561 resp, err := c.Commit(context.Background(), commitRequest) 562 if err != nil { 563 t.Fatal(err) 564 } 565 if resp.CommitTimestamp == nil { 566 t.Fatalf("No commit timestamp returned") 567 } 568} 569 570func TestRollbackTransaction(t *testing.T) { 571 setup() 572 c, err := apiv1.NewClient(context.Background(), clientOpt) 573 if err != nil { 574 t.Fatal(err) 575 } 576 577 var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") 578 var createRequest = &spannerpb.CreateSessionRequest{ 579 Database: formattedDatabase, 580 } 581 session, err := c.CreateSession(context.Background(), createRequest) 582 if err != nil { 583 t.Fatal(err) 584 } 585 beginRequest := &spannerpb.BeginTransactionRequest{ 586 Session: session.Name, 587 Options: &spannerpb.TransactionOptions{ 588 Mode: &spannerpb.TransactionOptions_ReadWrite_{ 589 ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, 590 }, 591 }, 592 } 593 tx, err := c.BeginTransaction(context.Background(), beginRequest) 594 if err != nil { 595 t.Fatal(err) 596 } 597 rollbackRequest := &spannerpb.RollbackRequest{ 598 Session: session.Name, 599 TransactionId: tx.Id, 600 } 601 err = c.Rollback(context.Background(), rollbackRequest) 602 if err != nil { 603 t.Fatal(err) 604 } 605} 606