1// Copyright 2019 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package testutil 16 17import ( 18 "bytes" 19 "context" 20 "fmt" 21 "math/rand" 22 "sort" 23 "strings" 24 "sync" 25 "time" 26 27 emptypb "github.com/golang/protobuf/ptypes/empty" 28 structpb "github.com/golang/protobuf/ptypes/struct" 29 "github.com/golang/protobuf/ptypes/timestamp" 30 "google.golang.org/genproto/googleapis/rpc/status" 31 spannerpb "google.golang.org/genproto/googleapis/spanner/v1" 32 "google.golang.org/grpc/codes" 33 gstatus "google.golang.org/grpc/status" 34) 35 36// StatementResultType indicates the type of result returned by a SQL 37// statement. 38type StatementResultType int 39 40const ( 41 // StatementResultError indicates that the sql statement returns an error. 42 StatementResultError StatementResultType = 0 43 // StatementResultResultSet indicates that the sql statement returns a 44 // result set. 45 StatementResultResultSet StatementResultType = 1 46 // StatementResultUpdateCount indicates that the sql statement returns an 47 // update count. 48 StatementResultUpdateCount StatementResultType = 2 49 // MaxRowsPerPartialResultSet is the maximum number of rows returned in 50 // each PartialResultSet. This number is deliberately set to a low value to 51 // ensure that most queries return more than one PartialResultSet. 52 MaxRowsPerPartialResultSet = 1 53) 54 55// The method names that can be used to register execution times and errors. 
const (
	// MethodBeginTransaction is the key for the BeginTransaction RPC.
	MethodBeginTransaction string = "BEGIN_TRANSACTION"
	// MethodCommitTransaction is the key for the Commit RPC.
	MethodCommitTransaction string = "COMMIT_TRANSACTION"
	// MethodBatchCreateSession is the key for the BatchCreateSessions RPC.
	MethodBatchCreateSession string = "BATCH_CREATE_SESSION"
	// MethodCreateSession is the key for the CreateSession RPC.
	MethodCreateSession string = "CREATE_SESSION"
	// MethodDeleteSession is the key for the DeleteSession RPC.
	MethodDeleteSession string = "DELETE_SESSION"
	// MethodGetSession is the key for the GetSession RPC.
	MethodGetSession string = "GET_SESSION"
	// MethodExecuteSql is the key for the ExecuteSql RPC.
	MethodExecuteSql string = "EXECUTE_SQL"
	// MethodExecuteStreamingSql is the key for the ExecuteStreamingSql RPC.
	MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL"
)

// StatementResult represents a mocked result on the test server. The result is
// one of: a ResultSet, an update count or an error.
type StatementResult struct {
	// Type selects which of the fields below the server uses.
	Type StatementResultType
	// Err is returned when Type is StatementResultError.
	Err error
	// ResultSet is returned when Type is StatementResultResultSet.
	ResultSet *spannerpb.ResultSet
	// UpdateCount is reported as the row count in the result set statistics
	// when Type is StatementResultUpdateCount.
	UpdateCount int64
}

// PartialResultSetExecutionTime represents the execution time and error that
// should be used when the PartialResultSet with the specified resume token is
// about to be returned.
type PartialResultSetExecutionTime struct {
	// ResumeToken identifies the PartialResultSet this entry applies to; it
	// is compared against the resume token of each part being streamed.
	ResumeToken []byte
	// ExecutionTime, when positive, is how long the server waits before
	// handling the matching part.
	ExecutionTime time.Duration
	// Err, when non-nil, is returned instead of sending the matching part.
	Err error
}

// Converts a ResultSet to a list of PartialResultSets. This method is used to
// convert a mocked result to PartialResultSets when one of the streaming
// methods is called.
88func (s *StatementResult) toPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) { 89 var startIndex uint64 90 if len(resumeToken) > 0 { 91 if startIndex, err = DecodeResumeToken(resumeToken); err != nil { 92 return nil, err 93 } 94 } 95 96 totalRows := uint64(len(s.ResultSet.Rows)) 97 for { 98 rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet)) 99 rows := s.ResultSet.Rows[startIndex : startIndex+rowCount] 100 values := make([]*structpb.Value, 101 len(rows)*len(s.ResultSet.Metadata.RowType.Fields)) 102 var idx int 103 for _, row := range rows { 104 for colIdx := range s.ResultSet.Metadata.RowType.Fields { 105 values[idx] = row.Values[colIdx] 106 idx++ 107 } 108 } 109 result = append(result, &spannerpb.PartialResultSet{ 110 Metadata: s.ResultSet.Metadata, 111 Values: values, 112 ResumeToken: EncodeResumeToken(startIndex + rowCount), 113 }) 114 startIndex += rowCount 115 if startIndex == totalRows { 116 break 117 } 118 } 119 return result, nil 120} 121 122func min(x, y uint64) uint64 { 123 if x > y { 124 return y 125 } 126 return x 127} 128 129func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet { 130 return &spannerpb.PartialResultSet{ 131 Stats: s.convertUpdateCountToResultSet(exact).Stats, 132 } 133} 134 135// Converts an update count to a ResultSet, as DML statements also return the 136// update count as the statistics of a ResultSet. 
137func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet { 138 if exact { 139 return &spannerpb.ResultSet{ 140 Stats: &spannerpb.ResultSetStats{ 141 RowCount: &spannerpb.ResultSetStats_RowCountExact{ 142 RowCountExact: s.UpdateCount, 143 }, 144 }, 145 } 146 } 147 return &spannerpb.ResultSet{ 148 Stats: &spannerpb.ResultSetStats{ 149 RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{ 150 RowCountLowerBound: s.UpdateCount, 151 }, 152 }, 153 } 154} 155 156// SimulatedExecutionTime represents the time the execution of a method 157// should take, and any errors that should be returned by the method. 158type SimulatedExecutionTime struct { 159 MinimumExecutionTime time.Duration 160 RandomExecutionTime time.Duration 161 Errors []error 162 // Keep error after execution. The error will continue to be returned until 163 // it is cleared. 164 KeepError bool 165} 166 167// InMemSpannerServer contains the SpannerServer interface plus a couple 168// of specific methods for adding mocked results and resetting the server. 169type InMemSpannerServer interface { 170 spannerpb.SpannerServer 171 172 // Stops this server. 173 Stop() 174 175 // Resets the in-mem server to its default state, deleting all sessions and 176 // transactions that have been created on the server. Mocked results are 177 // not deleted. 178 Reset() 179 180 // Sets an error that will be returned by the next server call. The server 181 // call will also automatically clear the error. 182 SetError(err error) 183 184 // Puts a mocked result on the server for a specific sql statement. The 185 // server does not parse the SQL string in any way, it is merely used as 186 // a key to the mocked result. The result will be used for all methods that 187 // expect a SQL statement, including (batch) DML methods. 
188 PutStatementResult(sql string, result *StatementResult) error 189 190 // Adds a PartialResultSetExecutionTime to the server that should be returned 191 // for the specified SQL string. 192 AddPartialResultSetError(sql string, err PartialResultSetExecutionTime) 193 194 // Removes a mocked result on the server for a specific sql statement. 195 RemoveStatementResult(sql string) 196 197 // Aborts the specified transaction . This method can be used to test 198 // transaction retry logic. 199 AbortTransaction(id []byte) 200 201 // Puts a simulated execution time for one of the Spanner methods. 202 PutExecutionTime(method string, executionTime SimulatedExecutionTime) 203 // Freeze stalls all requests. 204 Freeze() 205 // Unfreeze restores processing requests. 206 Unfreeze() 207 208 TotalSessionsCreated() uint 209 TotalSessionsDeleted() uint 210 SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) 211 SetMaxSessionsReturnedByServerInTotal(sessionCount int32) 212 213 ReceivedRequests() chan interface{} 214 DumpSessions() map[string]bool 215 ClearPings() 216 DumpPings() []string 217} 218 219type inMemSpannerServer struct { 220 // Embed for forward compatibility. 221 // Tests will keep working if more methods are added 222 // in the future. 223 spannerpb.SpannerServer 224 225 mu sync.Mutex 226 // Set to true when this server been stopped. This is the end state of a 227 // server, a stopped server cannot be restarted. 228 stopped bool 229 // If set, all calls return this error. 230 err error 231 // The mock server creates session IDs using this counter. 232 sessionCounter uint64 233 // The sessions that have been created on this mock server. 234 sessions map[string]*spannerpb.Session 235 // Last use times per session. 236 sessionLastUseTime map[string]time.Time 237 // The mock server creates transaction IDs per session using these 238 // counters. 239 transactionCounters map[string]*uint64 240 // The transactions that have been created on this mock server. 
241 transactions map[string]*spannerpb.Transaction 242 // The transactions that have been (manually) aborted on the server. 243 abortedTransactions map[string]bool 244 // The transactions that are marked as PartitionedDMLTransaction 245 partitionedDmlTransactions map[string]bool 246 // The mocked results for this server. 247 statementResults map[string]*StatementResult 248 // The simulated execution times per method. 249 executionTimes map[string]*SimulatedExecutionTime 250 // The simulated errors for partial result sets 251 partialResultSetErrors map[string][]*PartialResultSetExecutionTime 252 253 totalSessionsCreated uint 254 totalSessionsDeleted uint 255 // The maximum number of sessions that will be created per batch request. 256 maxSessionsReturnedByServerPerBatchRequest int32 257 maxSessionsReturnedByServerInTotal int32 258 receivedRequests chan interface{} 259 // Session ping history. 260 pings []string 261 262 // Server will stall on any requests. 263 freezed chan struct{} 264} 265 266// NewInMemSpannerServer creates a new in-mem test server. 267func NewInMemSpannerServer() InMemSpannerServer { 268 res := &inMemSpannerServer{} 269 res.initDefaults() 270 res.statementResults = make(map[string]*StatementResult) 271 res.executionTimes = make(map[string]*SimulatedExecutionTime) 272 res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime) 273 res.receivedRequests = make(chan interface{}, 1000000) 274 // Produce a closed channel, so the default action of ready is to not block. 275 res.Freeze() 276 res.Unfreeze() 277 return res 278} 279 280func (s *inMemSpannerServer) Stop() { 281 s.mu.Lock() 282 defer s.mu.Unlock() 283 s.stopped = true 284 close(s.receivedRequests) 285} 286 287// Resets the test server to its initial state, deleting all sessions and 288// transactions that have been created on the server. This method will not 289// remove mocked results. 
290func (s *inMemSpannerServer) Reset() { 291 s.mu.Lock() 292 defer s.mu.Unlock() 293 close(s.receivedRequests) 294 s.receivedRequests = make(chan interface{}, 1000000) 295 s.initDefaults() 296} 297 298func (s *inMemSpannerServer) SetError(err error) { 299 s.mu.Lock() 300 defer s.mu.Unlock() 301 s.err = err 302} 303 304// Registers a mocked result for a SQL statement on the server. 305func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error { 306 s.mu.Lock() 307 defer s.mu.Unlock() 308 s.statementResults[sql] = result 309 return nil 310} 311 312func (s *inMemSpannerServer) RemoveStatementResult(sql string) { 313 s.mu.Lock() 314 defer s.mu.Unlock() 315 delete(s.statementResults, sql) 316} 317 318func (s *inMemSpannerServer) AbortTransaction(id []byte) { 319 s.mu.Lock() 320 defer s.mu.Unlock() 321 s.abortedTransactions[string(id)] = true 322} 323 324func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) { 325 s.mu.Lock() 326 defer s.mu.Unlock() 327 s.executionTimes[method] = &executionTime 328} 329 330func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) { 331 s.mu.Lock() 332 defer s.mu.Unlock() 333 s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError) 334} 335 336// Freeze stalls all requests. 337func (s *inMemSpannerServer) Freeze() { 338 s.mu.Lock() 339 defer s.mu.Unlock() 340 s.freezed = make(chan struct{}) 341} 342 343// Unfreeze restores processing requests. 
344func (s *inMemSpannerServer) Unfreeze() { 345 s.mu.Lock() 346 defer s.mu.Unlock() 347 close(s.freezed) 348} 349 350// ready checks conditions before executing requests 351func (s *inMemSpannerServer) ready() { 352 s.mu.Lock() 353 freezed := s.freezed 354 s.mu.Unlock() 355 // check if server should be freezed 356 <-freezed 357} 358 359func (s *inMemSpannerServer) TotalSessionsCreated() uint { 360 s.mu.Lock() 361 defer s.mu.Unlock() 362 return s.totalSessionsCreated 363} 364 365func (s *inMemSpannerServer) TotalSessionsDeleted() uint { 366 s.mu.Lock() 367 defer s.mu.Unlock() 368 return s.totalSessionsDeleted 369} 370 371func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) { 372 s.mu.Lock() 373 defer s.mu.Unlock() 374 s.maxSessionsReturnedByServerPerBatchRequest = sessionCount 375} 376 377func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) { 378 s.mu.Lock() 379 defer s.mu.Unlock() 380 s.maxSessionsReturnedByServerInTotal = sessionCount 381} 382 383func (s *inMemSpannerServer) ReceivedRequests() chan interface{} { 384 return s.receivedRequests 385} 386 387// ClearPings clears the ping history from the server. 388func (s *inMemSpannerServer) ClearPings() { 389 s.mu.Lock() 390 defer s.mu.Unlock() 391 s.pings = nil 392} 393 394// DumpPings dumps the ping history. 395func (s *inMemSpannerServer) DumpPings() []string { 396 s.mu.Lock() 397 defer s.mu.Unlock() 398 return append([]string(nil), s.pings...) 399} 400 401// DumpSessions dumps the internal session table. 
402func (s *inMemSpannerServer) DumpSessions() map[string]bool { 403 s.mu.Lock() 404 defer s.mu.Unlock() 405 st := map[string]bool{} 406 for s := range s.sessions { 407 st[s] = true 408 } 409 return st 410} 411 412func (s *inMemSpannerServer) initDefaults() { 413 s.sessionCounter = 0 414 s.maxSessionsReturnedByServerPerBatchRequest = 100 415 s.sessions = make(map[string]*spannerpb.Session) 416 s.sessionLastUseTime = make(map[string]time.Time) 417 s.transactions = make(map[string]*spannerpb.Transaction) 418 s.abortedTransactions = make(map[string]bool) 419 s.partitionedDmlTransactions = make(map[string]bool) 420 s.transactionCounters = make(map[string]*uint64) 421} 422 423func (s *inMemSpannerServer) generateSessionNameLocked(database string) string { 424 s.sessionCounter++ 425 return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter) 426} 427 428func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) { 429 s.mu.Lock() 430 defer s.mu.Unlock() 431 session := s.sessions[name] 432 if session == nil { 433 return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session %s not found", name)) 434 } 435 return session, nil 436} 437 438func (s *inMemSpannerServer) updateSessionLastUseTime(session string) { 439 s.mu.Lock() 440 defer s.mu.Unlock() 441 s.sessionLastUseTime[session] = time.Now() 442} 443 444func getCurrentTimestamp() *timestamp.Timestamp { 445 t := time.Now() 446 return ×tamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())} 447} 448 449// Gets the transaction id from the transaction selector. If the selector 450// specifies that a new transaction should be started, this method will start 451// a new transaction and return the id of that transaction. 452func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte { 453 var res []byte 454 if txSelector.GetBegin() != nil { 455 // Start a new transaction. 
456 res = s.beginTransaction(session, txSelector.GetBegin()).Id 457 } else if txSelector.GetId() != nil { 458 res = txSelector.GetId() 459 } 460 return res 461} 462 463func (s *inMemSpannerServer) generateTransactionName(session string) string { 464 s.mu.Lock() 465 defer s.mu.Unlock() 466 counter, ok := s.transactionCounters[session] 467 if !ok { 468 counter = new(uint64) 469 s.transactionCounters[session] = counter 470 } 471 *counter++ 472 return fmt.Sprintf("%s/transactions/%d", session, *counter) 473} 474 475func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction { 476 id := s.generateTransactionName(session.Name) 477 res := &spannerpb.Transaction{ 478 Id: []byte(id), 479 ReadTimestamp: getCurrentTimestamp(), 480 } 481 s.mu.Lock() 482 s.transactions[id] = res 483 s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil 484 s.mu.Unlock() 485 return res 486} 487 488func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { 489 s.mu.Lock() 490 defer s.mu.Unlock() 491 tx, ok := s.transactions[string(id)] 492 if !ok { 493 return nil, gstatus.Error(codes.NotFound, "Transaction not found") 494 } 495 aborted, ok := s.abortedTransactions[string(id)] 496 if ok && aborted { 497 return nil, gstatus.Error(codes.Aborted, "Transaction has been aborted") 498 } 499 return tx, nil 500} 501 502func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) { 503 s.mu.Lock() 504 defer s.mu.Unlock() 505 delete(s.transactions, string(tx.Id)) 506 delete(s.partitionedDmlTransactions, string(tx.Id)) 507} 508 509func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) { 510 s.mu.Lock() 511 defer s.mu.Unlock() 512 result, ok := s.statementResults[sql] 513 if !ok { 514 return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql)) 515 } 516 return result, nil 517} 518 519func (s 
*inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { 520 s.mu.Lock() 521 if s.stopped { 522 s.mu.Unlock() 523 return gstatus.Error(codes.Unavailable, "server has been stopped") 524 } 525 s.receivedRequests <- req 526 s.mu.Unlock() 527 s.ready() 528 s.mu.Lock() 529 if s.err != nil { 530 err := s.err 531 s.err = nil 532 s.mu.Unlock() 533 return err 534 } 535 executionTime, ok := s.executionTimes[method] 536 s.mu.Unlock() 537 if ok { 538 var randTime int64 539 if executionTime.RandomExecutionTime > 0 { 540 randTime = rand.Int63n(int64(executionTime.RandomExecutionTime)) 541 } 542 totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) 543 <-time.After(totalExecutionTime) 544 s.mu.Lock() 545 if executionTime.Errors != nil && len(executionTime.Errors) > 0 { 546 err := executionTime.Errors[0] 547 if !executionTime.KeepError { 548 executionTime.Errors = executionTime.Errors[1:] 549 } 550 s.mu.Unlock() 551 return err 552 } 553 s.mu.Unlock() 554 } 555 return nil 556} 557 558func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { 559 if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { 560 return nil, err 561 } 562 if req.Database == "" { 563 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 564 } 565 s.mu.Lock() 566 defer s.mu.Unlock() 567 if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal { 568 return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") 569 } 570 sessionName := s.generateSessionNameLocked(req.Database) 571 ts := getCurrentTimestamp() 572 session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} 573 s.totalSessionsCreated++ 574 s.sessions[sessionName] = session 575 return session, nil 576} 577 578func (s *inMemSpannerServer) BatchCreateSessions(ctx 
context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { 579 if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil { 580 return nil, err 581 } 582 if req.Database == "" { 583 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 584 } 585 if req.SessionCount <= 0 { 586 return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0") 587 } 588 sessionsToCreate := req.SessionCount 589 s.mu.Lock() 590 defer s.mu.Unlock() 591 if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal { 592 return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") 593 } 594 if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest { 595 sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest 596 } 597 if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal { 598 sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions)) 599 } 600 sessions := make([]*spannerpb.Session, sessionsToCreate) 601 for i := int32(0); i < sessionsToCreate; i++ { 602 sessionName := s.generateSessionNameLocked(req.Database) 603 ts := getCurrentTimestamp() 604 sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} 605 s.totalSessionsCreated++ 606 s.sessions[sessionName] = sessions[i] 607 } 608 return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil 609} 610 611func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { 612 if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { 613 return nil, err 614 } 615 s.mu.Lock() 616 s.pings = append(s.pings, req.Name) 617 s.mu.Unlock() 618 if req.Name == "" { 619 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 620 } 621 
session, err := s.findSession(req.Name) 622 if err != nil { 623 return nil, err 624 } 625 return session, nil 626} 627 628func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) { 629 s.mu.Lock() 630 if s.stopped { 631 s.mu.Unlock() 632 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 633 } 634 s.receivedRequests <- req 635 s.mu.Unlock() 636 if req.Database == "" { 637 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 638 } 639 expectedSessionName := req.Database + "/sessions/" 640 var sessions []*spannerpb.Session 641 s.mu.Lock() 642 for _, session := range s.sessions { 643 if strings.Index(session.Name, expectedSessionName) == 0 { 644 sessions = append(sessions, session) 645 } 646 } 647 s.mu.Unlock() 648 sort.Slice(sessions[:], func(i, j int) bool { 649 return sessions[i].Name < sessions[j].Name 650 }) 651 res := &spannerpb.ListSessionsResponse{Sessions: sessions} 652 return res, nil 653} 654 655func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { 656 if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { 657 return nil, err 658 } 659 if req.Name == "" { 660 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 661 } 662 if _, err := s.findSession(req.Name); err != nil { 663 return nil, err 664 } 665 s.mu.Lock() 666 defer s.mu.Unlock() 667 s.totalSessionsDeleted++ 668 delete(s.sessions, req.Name) 669 return &emptypb.Empty{}, nil 670} 671 672func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { 673 if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil { 674 return nil, err 675 } 676 if req.Session == "" { 677 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 678 } 679 session, err := s.findSession(req.Session) 680 if err 
!= nil { 681 return nil, err 682 } 683 var id []byte 684 s.updateSessionLastUseTime(session.Name) 685 if id = s.getTransactionID(session, req.Transaction); id != nil { 686 _, err = s.getTransactionByID(id) 687 if err != nil { 688 return nil, err 689 } 690 } 691 statementResult, err := s.getStatementResult(req.Sql) 692 if err != nil { 693 return nil, err 694 } 695 s.mu.Lock() 696 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 697 s.mu.Unlock() 698 switch statementResult.Type { 699 case StatementResultError: 700 return nil, statementResult.Err 701 case StatementResultResultSet: 702 return statementResult.ResultSet, nil 703 case StatementResultUpdateCount: 704 return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil 705 } 706 return nil, gstatus.Error(codes.Internal, "Unknown result type") 707} 708 709func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { 710 if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { 711 return err 712 } 713 if req.Session == "" { 714 return gstatus.Error(codes.InvalidArgument, "Missing session name") 715 } 716 session, err := s.findSession(req.Session) 717 if err != nil { 718 return err 719 } 720 s.updateSessionLastUseTime(session.Name) 721 var id []byte 722 if id = s.getTransactionID(session, req.Transaction); id != nil { 723 _, err = s.getTransactionByID(id) 724 if err != nil { 725 return err 726 } 727 } 728 statementResult, err := s.getStatementResult(req.Sql) 729 if err != nil { 730 return err 731 } 732 s.mu.Lock() 733 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 734 s.mu.Unlock() 735 switch statementResult.Type { 736 case StatementResultError: 737 return statementResult.Err 738 case StatementResultResultSet: 739 parts, err := statementResult.toPartialResultSets(req.ResumeToken) 740 if err != nil { 741 return err 742 } 743 var nextPartialResultSetError 
*PartialResultSetExecutionTime 744 s.mu.Lock() 745 pErrors := s.partialResultSetErrors[req.Sql] 746 if len(pErrors) > 0 { 747 nextPartialResultSetError = pErrors[0] 748 s.partialResultSetErrors[req.Sql] = pErrors[1:] 749 } 750 s.mu.Unlock() 751 for _, part := range parts { 752 if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) { 753 if nextPartialResultSetError.ExecutionTime > 0 { 754 <-time.After(nextPartialResultSetError.ExecutionTime) 755 } 756 if nextPartialResultSetError.Err != nil { 757 return nextPartialResultSetError.Err 758 } 759 } 760 if err := stream.Send(part); err != nil { 761 return err 762 } 763 } 764 return nil 765 case StatementResultUpdateCount: 766 part := statementResult.updateCountToPartialResultSet(!isPartitionedDml) 767 if err := stream.Send(part); err != nil { 768 return err 769 } 770 return nil 771 } 772 return gstatus.Error(codes.Internal, "Unknown result type") 773} 774 775func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { 776 s.mu.Lock() 777 if s.stopped { 778 s.mu.Unlock() 779 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 780 } 781 s.receivedRequests <- req 782 s.mu.Unlock() 783 if req.Session == "" { 784 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 785 } 786 session, err := s.findSession(req.Session) 787 if err != nil { 788 return nil, err 789 } 790 s.updateSessionLastUseTime(session.Name) 791 var id []byte 792 if id = s.getTransactionID(session, req.Transaction); id != nil { 793 _, err = s.getTransactionByID(id) 794 if err != nil { 795 return nil, err 796 } 797 } 798 s.mu.Lock() 799 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 800 s.mu.Unlock() 801 resp := &spannerpb.ExecuteBatchDmlResponse{} 802 resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements)) 803 for idx, batchStatement := range 
req.Statements { 804 statementResult, err := s.getStatementResult(batchStatement.Sql) 805 if err != nil { 806 return nil, err 807 } 808 switch statementResult.Type { 809 case StatementResultError: 810 resp.Status = &status.Status{Code: int32(codes.Unknown)} 811 case StatementResultResultSet: 812 return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql)) 813 case StatementResultUpdateCount: 814 resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml) 815 resp.Status = &status.Status{Code: int32(codes.OK)} 816 } 817 } 818 return resp, nil 819} 820 821func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) { 822 s.mu.Lock() 823 if s.stopped { 824 s.mu.Unlock() 825 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 826 } 827 s.receivedRequests <- req 828 s.mu.Unlock() 829 return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") 830} 831 832func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { 833 s.mu.Lock() 834 if s.stopped { 835 s.mu.Unlock() 836 return gstatus.Error(codes.Unavailable, "server has been stopped") 837 } 838 s.receivedRequests <- req 839 s.mu.Unlock() 840 return gstatus.Error(codes.Unimplemented, "Method not yet implemented") 841} 842 843func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { 844 if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil { 845 return nil, err 846 } 847 if req.Session == "" { 848 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 849 } 850 session, err := s.findSession(req.Session) 851 if err != nil { 852 return nil, err 853 } 854 s.updateSessionLastUseTime(session.Name) 855 tx := s.beginTransaction(session, req.Options) 856 return tx, nil 857} 858 
859func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) { 860 if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil { 861 return nil, err 862 } 863 if req.Session == "" { 864 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 865 } 866 session, err := s.findSession(req.Session) 867 if err != nil { 868 return nil, err 869 } 870 s.updateSessionLastUseTime(session.Name) 871 var tx *spannerpb.Transaction 872 if req.GetSingleUseTransaction() != nil { 873 tx = s.beginTransaction(session, req.GetSingleUseTransaction()) 874 } else if req.GetTransactionId() != nil { 875 tx, err = s.getTransactionByID(req.GetTransactionId()) 876 if err != nil { 877 return nil, err 878 } 879 } else { 880 return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request") 881 } 882 s.removeTransaction(tx) 883 return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil 884} 885 886func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { 887 s.mu.Lock() 888 if s.stopped { 889 s.mu.Unlock() 890 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 891 } 892 s.receivedRequests <- req 893 s.mu.Unlock() 894 if req.Session == "" { 895 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 896 } 897 session, err := s.findSession(req.Session) 898 if err != nil { 899 return nil, err 900 } 901 s.updateSessionLastUseTime(session.Name) 902 tx, err := s.getTransactionByID(req.TransactionId) 903 if err != nil { 904 return nil, err 905 } 906 s.removeTransaction(tx) 907 return &emptypb.Empty{}, nil 908} 909 910func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { 911 s.mu.Lock() 912 if s.stopped { 913 s.mu.Unlock() 914 return nil, gstatus.Error(codes.Unavailable, "server has 
been stopped") 915 } 916 s.receivedRequests <- req 917 s.mu.Unlock() 918 return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") 919} 920 921func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { 922 s.mu.Lock() 923 if s.stopped { 924 s.mu.Unlock() 925 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 926 } 927 s.receivedRequests <- req 928 s.mu.Unlock() 929 return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") 930} 931