1// Copyright 2019 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package testutil 16 17import ( 18 "bytes" 19 "context" 20 "fmt" 21 "math/rand" 22 "sort" 23 "strings" 24 "sync" 25 "time" 26 27 "github.com/golang/protobuf/ptypes" 28 emptypb "github.com/golang/protobuf/ptypes/empty" 29 structpb "github.com/golang/protobuf/ptypes/struct" 30 "github.com/golang/protobuf/ptypes/timestamp" 31 "google.golang.org/genproto/googleapis/rpc/errdetails" 32 "google.golang.org/genproto/googleapis/rpc/status" 33 spannerpb "google.golang.org/genproto/googleapis/spanner/v1" 34 "google.golang.org/grpc/codes" 35 gstatus "google.golang.org/grpc/status" 36) 37 38// StatementResultType indicates the type of result returned by a SQL 39// statement. 40type StatementResultType int 41 42const ( 43 // StatementResultError indicates that the sql statement returns an error. 44 StatementResultError StatementResultType = 0 45 // StatementResultResultSet indicates that the sql statement returns a 46 // result set. 47 StatementResultResultSet StatementResultType = 1 48 // StatementResultUpdateCount indicates that the sql statement returns an 49 // update count. 50 StatementResultUpdateCount StatementResultType = 2 51 // MaxRowsPerPartialResultSet is the maximum number of rows returned in 52 // each PartialResultSet. This number is deliberately set to a low value to 53 // ensure that most queries return more than one PartialResultSet. 54 MaxRowsPerPartialResultSet = 1 55) 56 57// The method names that can be used to register execution times and errors. 58const ( 59 MethodBeginTransaction string = "BEGIN_TRANSACTION" 60 MethodCommitTransaction string = "COMMIT_TRANSACTION" 61 MethodBatchCreateSession string = "BATCH_CREATE_SESSION" 62 MethodCreateSession string = "CREATE_SESSION" 63 MethodDeleteSession string = "DELETE_SESSION" 64 MethodGetSession string = "GET_SESSION" 65 MethodExecuteSql string = "EXECUTE_SQL" 66 MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL" 67 MethodExecuteBatchDml string = "EXECUTE_BATCH_DML" 68) 69 70// StatementResult represents a mocked result on the test server. The result is 71// either of: a ResultSet, an update count or an error. 72type StatementResult struct { 73 Type StatementResultType 74 Err error 75 ResultSet *spannerpb.ResultSet 76 UpdateCount int64 77} 78 79// PartialResultSetExecutionTime represents execution times and errors that 80// should be used when a PartialResult at the specified resume token is to 81// be returned. 82type PartialResultSetExecutionTime struct { 83 ResumeToken []byte 84 ExecutionTime time.Duration 85 Err error 86} 87 88// Converts a ResultSet to a PartialResultSet. This method is used to convert 89// a mocked result to a PartialResultSet when one of the streaming methods are 90// called. 91func (s *StatementResult) toPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) { 92 var startIndex uint64 93 if len(resumeToken) > 0 { 94 if startIndex, err = DecodeResumeToken(resumeToken); err != nil { 95 return nil, err 96 } 97 } 98 99 totalRows := uint64(len(s.ResultSet.Rows)) 100 for { 101 rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet)) 102 rows := s.ResultSet.Rows[startIndex : startIndex+rowCount] 103 values := make([]*structpb.Value, 104 len(rows)*len(s.ResultSet.Metadata.RowType.Fields)) 105 var idx int 106 for _, row := range rows { 107 for colIdx := range s.ResultSet.Metadata.RowType.Fields { 108 values[idx] = row.Values[colIdx] 109 idx++ 110 } 111 } 112 result = append(result, &spannerpb.PartialResultSet{ 113 Metadata: s.ResultSet.Metadata, 114 Values: values, 115 ResumeToken: EncodeResumeToken(startIndex + rowCount), 116 }) 117 startIndex += rowCount 118 if startIndex == totalRows { 119 break 120 } 121 } 122 return result, nil 123} 124 125func min(x, y uint64) uint64 { 126 if x > y { 127 return y 128 } 129 return x 130} 131 132func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet { 133 return &spannerpb.PartialResultSet{ 134 Stats: s.convertUpdateCountToResultSet(exact).Stats, 135 } 136} 137 138// Converts an update count to a ResultSet, as DML statements also return the 139// update count as the statistics of a ResultSet. 140func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet { 141 if exact { 142 return &spannerpb.ResultSet{ 143 Stats: &spannerpb.ResultSetStats{ 144 RowCount: &spannerpb.ResultSetStats_RowCountExact{ 145 RowCountExact: s.UpdateCount, 146 }, 147 }, 148 } 149 } 150 return &spannerpb.ResultSet{ 151 Stats: &spannerpb.ResultSetStats{ 152 RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{ 153 RowCountLowerBound: s.UpdateCount, 154 }, 155 }, 156 } 157} 158 159// SimulatedExecutionTime represents the time the execution of a method 160// should take, and any errors that should be returned by the method. 161type SimulatedExecutionTime struct { 162 MinimumExecutionTime time.Duration 163 RandomExecutionTime time.Duration 164 Errors []error 165 // Keep error after execution. The error will continue to be returned until 166 // it is cleared. 167 KeepError bool 168} 169 170// InMemSpannerServer contains the SpannerServer interface plus a couple 171// of specific methods for adding mocked results and resetting the server. 172type InMemSpannerServer interface { 173 spannerpb.SpannerServer 174 175 // Stops this server. 176 Stop() 177 178 // Resets the in-mem server to its default state, deleting all sessions and 179 // transactions that have been created on the server. Mocked results are 180 // not deleted. 181 Reset() 182 183 // Sets an error that will be returned by the next server call. The server 184 // call will also automatically clear the error. 185 SetError(err error) 186 187 // Puts a mocked result on the server for a specific sql statement. The 188 // server does not parse the SQL string in any way, it is merely used as 189 // a key to the mocked result. The result will be used for all methods that 190 // expect a SQL statement, including (batch) DML methods. 191 PutStatementResult(sql string, result *StatementResult) error 192 193 // Puts a mocked result on the server for a specific partition token. The 194 // result will only be used for query requests that specify a partition 195 // token. 196 PutPartitionResult(partitionToken []byte, result *StatementResult) error 197 198 // Adds a PartialResultSetExecutionTime to the server that should be returned 199 // for the specified SQL string. 200 AddPartialResultSetError(sql string, err PartialResultSetExecutionTime) 201 202 // Removes a mocked result on the server for a specific sql statement. 203 RemoveStatementResult(sql string) 204 205 // Aborts the specified transaction . This method can be used to test 206 // transaction retry logic. 207 AbortTransaction(id []byte) 208 209 // Puts a simulated execution time for one of the Spanner methods. 210 PutExecutionTime(method string, executionTime SimulatedExecutionTime) 211 // Freeze stalls all requests. 212 Freeze() 213 // Unfreeze restores processing requests. 214 Unfreeze() 215 216 TotalSessionsCreated() uint 217 TotalSessionsDeleted() uint 218 SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) 219 SetMaxSessionsReturnedByServerInTotal(sessionCount int32) 220 221 ReceivedRequests() chan interface{} 222 DumpSessions() map[string]bool 223 ClearPings() 224 DumpPings() []string 225} 226 227type inMemSpannerServer struct { 228 // Embed for forward compatibility. 229 // Tests will keep working if more methods are added 230 // in the future. 231 spannerpb.SpannerServer 232 233 mu sync.Mutex 234 // Set to true when this server been stopped. This is the end state of a 235 // server, a stopped server cannot be restarted. 236 stopped bool 237 // If set, all calls return this error. 238 err error 239 // The mock server creates session IDs using this counter. 240 sessionCounter uint64 241 // The sessions that have been created on this mock server. 242 sessions map[string]*spannerpb.Session 243 // Last use times per session. 244 sessionLastUseTime map[string]time.Time 245 // The mock server creates transaction IDs per session using these 246 // counters. 247 transactionCounters map[string]*uint64 248 // The transactions that have been created on this mock server. 249 transactions map[string]*spannerpb.Transaction 250 // The transactions that have been (manually) aborted on the server. 251 abortedTransactions map[string]bool 252 // The transactions that are marked as PartitionedDMLTransaction 253 partitionedDmlTransactions map[string]bool 254 // The mocked results for this server. 255 statementResults map[string]*StatementResult 256 partitionResults map[string]*StatementResult 257 // The simulated execution times per method. 258 executionTimes map[string]*SimulatedExecutionTime 259 // The simulated errors for partial result sets 260 partialResultSetErrors map[string][]*PartialResultSetExecutionTime 261 262 totalSessionsCreated uint 263 totalSessionsDeleted uint 264 // The maximum number of sessions that will be created per batch request. 265 maxSessionsReturnedByServerPerBatchRequest int32 266 maxSessionsReturnedByServerInTotal int32 267 receivedRequests chan interface{} 268 // Session ping history. 269 pings []string 270 271 // Server will stall on any requests. 272 freezed chan struct{} 273} 274 275// NewInMemSpannerServer creates a new in-mem test server. 276func NewInMemSpannerServer() InMemSpannerServer { 277 res := &inMemSpannerServer{} 278 res.initDefaults() 279 res.statementResults = make(map[string]*StatementResult) 280 res.partitionResults = make(map[string]*StatementResult) 281 res.executionTimes = make(map[string]*SimulatedExecutionTime) 282 res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime) 283 res.receivedRequests = make(chan interface{}, 1000000) 284 // Produce a closed channel, so the default action of ready is to not block. 285 res.Freeze() 286 res.Unfreeze() 287 return res 288} 289 290func (s *inMemSpannerServer) Stop() { 291 s.mu.Lock() 292 defer s.mu.Unlock() 293 s.stopped = true 294 close(s.receivedRequests) 295} 296 297// Resets the test server to its initial state, deleting all sessions and 298// transactions that have been created on the server. This method will not 299// remove mocked results. 300func (s *inMemSpannerServer) Reset() { 301 s.mu.Lock() 302 defer s.mu.Unlock() 303 close(s.receivedRequests) 304 s.receivedRequests = make(chan interface{}, 1000000) 305 s.initDefaults() 306} 307 308func (s *inMemSpannerServer) SetError(err error) { 309 s.mu.Lock() 310 defer s.mu.Unlock() 311 s.err = err 312} 313 314// Registers a mocked result for a SQL statement on the server. 315func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error { 316 s.mu.Lock() 317 defer s.mu.Unlock() 318 s.statementResults[sql] = result 319 return nil 320} 321 322func (s *inMemSpannerServer) RemoveStatementResult(sql string) { 323 s.mu.Lock() 324 defer s.mu.Unlock() 325 delete(s.statementResults, sql) 326} 327 328// Registers a mocked result for a partition token on the server. 329func (s *inMemSpannerServer) PutPartitionResult(partitionToken []byte, result *StatementResult) error { 330 tokenString := string(partitionToken) 331 s.mu.Lock() 332 defer s.mu.Unlock() 333 s.partitionResults[tokenString] = result 334 return nil 335} 336 337func (s *inMemSpannerServer) AbortTransaction(id []byte) { 338 s.mu.Lock() 339 defer s.mu.Unlock() 340 s.abortedTransactions[string(id)] = true 341} 342 343func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) { 344 s.mu.Lock() 345 defer s.mu.Unlock() 346 s.executionTimes[method] = &executionTime 347} 348 349func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) { 350 s.mu.Lock() 351 defer s.mu.Unlock() 352 s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError) 353} 354 355// Freeze stalls all requests. 356func (s *inMemSpannerServer) Freeze() { 357 s.mu.Lock() 358 defer s.mu.Unlock() 359 s.freezed = make(chan struct{}) 360} 361 362// Unfreeze restores processing requests. 363func (s *inMemSpannerServer) Unfreeze() { 364 s.mu.Lock() 365 defer s.mu.Unlock() 366 close(s.freezed) 367} 368 369// ready checks conditions before executing requests 370func (s *inMemSpannerServer) ready() { 371 s.mu.Lock() 372 freezed := s.freezed 373 s.mu.Unlock() 374 // check if server should be freezed 375 <-freezed 376} 377 378func (s *inMemSpannerServer) TotalSessionsCreated() uint { 379 s.mu.Lock() 380 defer s.mu.Unlock() 381 return s.totalSessionsCreated 382} 383 384func (s *inMemSpannerServer) TotalSessionsDeleted() uint { 385 s.mu.Lock() 386 defer s.mu.Unlock() 387 return s.totalSessionsDeleted 388} 389 390func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) { 391 s.mu.Lock() 392 defer s.mu.Unlock() 393 s.maxSessionsReturnedByServerPerBatchRequest = sessionCount 394} 395 396func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) { 397 s.mu.Lock() 398 defer s.mu.Unlock() 399 s.maxSessionsReturnedByServerInTotal = sessionCount 400} 401 402func (s *inMemSpannerServer) ReceivedRequests() chan interface{} { 403 return s.receivedRequests 404} 405 406// ClearPings clears the ping history from the server. 407func (s *inMemSpannerServer) ClearPings() { 408 s.mu.Lock() 409 defer s.mu.Unlock() 410 s.pings = nil 411} 412 413// DumpPings dumps the ping history. 414func (s *inMemSpannerServer) DumpPings() []string { 415 s.mu.Lock() 416 defer s.mu.Unlock() 417 return append([]string(nil), s.pings...) 418} 419 420// DumpSessions dumps the internal session table. 421func (s *inMemSpannerServer) DumpSessions() map[string]bool { 422 s.mu.Lock() 423 defer s.mu.Unlock() 424 st := map[string]bool{} 425 for s := range s.sessions { 426 st[s] = true 427 } 428 return st 429} 430 431func (s *inMemSpannerServer) initDefaults() { 432 s.sessionCounter = 0 433 s.maxSessionsReturnedByServerPerBatchRequest = 100 434 s.sessions = make(map[string]*spannerpb.Session) 435 s.sessionLastUseTime = make(map[string]time.Time) 436 s.transactions = make(map[string]*spannerpb.Transaction) 437 s.abortedTransactions = make(map[string]bool) 438 s.partitionedDmlTransactions = make(map[string]bool) 439 s.transactionCounters = make(map[string]*uint64) 440} 441 442func (s *inMemSpannerServer) generateSessionNameLocked(database string) string { 443 s.sessionCounter++ 444 return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter) 445} 446 447func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) { 448 s.mu.Lock() 449 defer s.mu.Unlock() 450 session := s.sessions[name] 451 if session == nil { 452 return nil, newSessionNotFoundError(name) 453 } 454 return session, nil 455} 456 457// sessionResourceType is the type name of Spanner sessions. 458const sessionResourceType = "type.googleapis.com/google.spanner.v1.Session" 459 460func newSessionNotFoundError(name string) error { 461 s := gstatus.Newf(codes.NotFound, "Session not found: Session with id %s not found", name) 462 s, _ = s.WithDetails(&errdetails.ResourceInfo{ResourceType: sessionResourceType, ResourceName: name}) 463 return s.Err() 464} 465 466func (s *inMemSpannerServer) updateSessionLastUseTime(session string) { 467 s.mu.Lock() 468 defer s.mu.Unlock() 469 s.sessionLastUseTime[session] = time.Now() 470} 471 472func getCurrentTimestamp() *timestamp.Timestamp { 473 t := time.Now() 474 return ×tamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())} 475} 476 477// Gets the transaction id from the transaction selector. If the selector 478// specifies that a new transaction should be started, this method will start 479// a new transaction and return the id of that transaction. 480func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte { 481 var res []byte 482 if txSelector.GetBegin() != nil { 483 // Start a new transaction. 484 res = s.beginTransaction(session, txSelector.GetBegin()).Id 485 } else if txSelector.GetId() != nil { 486 res = txSelector.GetId() 487 } 488 return res 489} 490 491func (s *inMemSpannerServer) generateTransactionName(session string) string { 492 s.mu.Lock() 493 defer s.mu.Unlock() 494 counter, ok := s.transactionCounters[session] 495 if !ok { 496 counter = new(uint64) 497 s.transactionCounters[session] = counter 498 } 499 *counter++ 500 return fmt.Sprintf("%s/transactions/%d", session, *counter) 501} 502 503func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction { 504 id := s.generateTransactionName(session.Name) 505 res := &spannerpb.Transaction{ 506 Id: []byte(id), 507 ReadTimestamp: getCurrentTimestamp(), 508 } 509 s.mu.Lock() 510 s.transactions[id] = res 511 s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil 512 s.mu.Unlock() 513 return res 514} 515 516func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { 517 s.mu.Lock() 518 defer s.mu.Unlock() 519 tx, ok := s.transactions[string(id)] 520 if !ok { 521 return nil, gstatus.Error(codes.NotFound, "Transaction not found") 522 } 523 aborted, ok := s.abortedTransactions[string(id)] 524 if ok && aborted { 525 return nil, newAbortedErrorWithMinimalRetryDelay() 526 } 527 return tx, nil 528} 529 530func newAbortedErrorWithMinimalRetryDelay() error { 531 st := gstatus.New(codes.Aborted, "Transaction has been aborted") 532 retry := &errdetails.RetryInfo{ 533 RetryDelay: ptypes.DurationProto(time.Nanosecond), 534 } 535 st, _ = st.WithDetails(retry) 536 return st.Err() 537} 538 539func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) { 540 s.mu.Lock() 541 defer s.mu.Unlock() 542 delete(s.transactions, string(tx.Id)) 543 delete(s.partitionedDmlTransactions, string(tx.Id)) 544} 545 546func (s *inMemSpannerServer) getPartitionResult(partitionToken []byte) (*StatementResult, error) { 547 tokenString := string(partitionToken) 548 s.mu.Lock() 549 defer s.mu.Unlock() 550 result, ok := s.partitionResults[tokenString] 551 if !ok { 552 return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for partition token %v", tokenString)) 553 } 554 return result, nil 555} 556 557func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) { 558 s.mu.Lock() 559 defer s.mu.Unlock() 560 result, ok := s.statementResults[sql] 561 if !ok { 562 return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql)) 563 } 564 return result, nil 565} 566 567func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { 568 s.mu.Lock() 569 if s.stopped { 570 s.mu.Unlock() 571 return gstatus.Error(codes.Unavailable, "server has been stopped") 572 } 573 s.receivedRequests <- req 574 s.mu.Unlock() 575 s.ready() 576 s.mu.Lock() 577 if s.err != nil { 578 err := s.err 579 s.err = nil 580 s.mu.Unlock() 581 return err 582 } 583 executionTime, ok := s.executionTimes[method] 584 s.mu.Unlock() 585 if ok { 586 var randTime int64 587 if executionTime.RandomExecutionTime > 0 { 588 randTime = rand.Int63n(int64(executionTime.RandomExecutionTime)) 589 } 590 totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) 591 <-time.After(totalExecutionTime) 592 s.mu.Lock() 593 if executionTime.Errors != nil && len(executionTime.Errors) > 0 { 594 err := executionTime.Errors[0] 595 if !executionTime.KeepError { 596 executionTime.Errors = executionTime.Errors[1:] 597 } 598 s.mu.Unlock() 599 return err 600 } 601 s.mu.Unlock() 602 } 603 return nil 604} 605 606func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { 607 if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { 608 return nil, err 609 } 610 if req.Database == "" { 611 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 612 } 613 s.mu.Lock() 614 defer s.mu.Unlock() 615 if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal { 616 return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") 617 } 618 sessionName := s.generateSessionNameLocked(req.Database) 619 ts := getCurrentTimestamp() 620 session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} 621 s.totalSessionsCreated++ 622 s.sessions[sessionName] = session 623 return session, nil 624} 625 626func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { 627 if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil { 628 return nil, err 629 } 630 if req.Database == "" { 631 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 632 } 633 if req.SessionCount <= 0 { 634 return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0") 635 } 636 sessionsToCreate := req.SessionCount 637 s.mu.Lock() 638 defer s.mu.Unlock() 639 if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal { 640 return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") 641 } 642 if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest { 643 sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest 644 } 645 if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal { 646 sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions)) 647 } 648 sessions := make([]*spannerpb.Session, sessionsToCreate) 649 for i := int32(0); i < sessionsToCreate; i++ { 650 sessionName := s.generateSessionNameLocked(req.Database) 651 ts := getCurrentTimestamp() 652 sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} 653 s.totalSessionsCreated++ 654 s.sessions[sessionName] = sessions[i] 655 } 656 return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil 657} 658 659func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { 660 if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { 661 return nil, err 662 } 663 if req.Name == "" { 664 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 665 } 666 session, err := s.findSession(req.Name) 667 if err != nil { 668 return nil, err 669 } 670 return session, nil 671} 672 673func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) { 674 s.mu.Lock() 675 if s.stopped { 676 s.mu.Unlock() 677 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 678 } 679 s.receivedRequests <- req 680 s.mu.Unlock() 681 if req.Database == "" { 682 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 683 } 684 expectedSessionName := req.Database + "/sessions/" 685 var sessions []*spannerpb.Session 686 s.mu.Lock() 687 for _, session := range s.sessions { 688 if strings.Index(session.Name, expectedSessionName) == 0 { 689 sessions = append(sessions, session) 690 } 691 } 692 s.mu.Unlock() 693 sort.Slice(sessions[:], func(i, j int) bool { 694 return sessions[i].Name < sessions[j].Name 695 }) 696 res := &spannerpb.ListSessionsResponse{Sessions: sessions} 697 return res, nil 698} 699 700func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { 701 if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { 702 return nil, err 703 } 704 if req.Name == "" { 705 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 706 } 707 if _, err := s.findSession(req.Name); err != nil { 708 return nil, err 709 } 710 s.mu.Lock() 711 defer s.mu.Unlock() 712 s.totalSessionsDeleted++ 713 delete(s.sessions, req.Name) 714 return &emptypb.Empty{}, nil 715} 716 717func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { 718 if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil { 719 return nil, err 720 } 721 if req.Sql == "SELECT 1" { 722 s.mu.Lock() 723 s.pings = append(s.pings, req.Session) 724 s.mu.Unlock() 725 } 726 if req.Session == "" { 727 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 728 } 729 session, err := s.findSession(req.Session) 730 if err != nil { 731 return nil, err 732 } 733 var id []byte 734 s.updateSessionLastUseTime(session.Name) 735 if id = s.getTransactionID(session, req.Transaction); id != nil { 736 _, err = s.getTransactionByID(id) 737 if err != nil { 738 return nil, err 739 } 740 } 741 var statementResult *StatementResult 742 if req.PartitionToken != nil { 743 statementResult, err = s.getPartitionResult(req.PartitionToken) 744 } else { 745 statementResult, err = s.getStatementResult(req.Sql) 746 } 747 if err != nil { 748 return nil, err 749 } 750 s.mu.Lock() 751 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 752 s.mu.Unlock() 753 switch statementResult.Type { 754 case StatementResultError: 755 return nil, statementResult.Err 756 case StatementResultResultSet: 757 return statementResult.ResultSet, nil 758 case StatementResultUpdateCount: 759 return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil 760 } 761 return nil, gstatus.Error(codes.Internal, "Unknown result type") 762} 763 764func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { 765 if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { 766 return err 767 } 768 if req.Session == "" { 769 return gstatus.Error(codes.InvalidArgument, "Missing session name") 770 } 771 session, err := s.findSession(req.Session) 772 if err != nil { 773 return err 774 } 775 s.updateSessionLastUseTime(session.Name) 776 var id []byte 777 if id = s.getTransactionID(session, req.Transaction); id != nil { 778 _, err = s.getTransactionByID(id) 779 if err != nil { 780 return err 781 } 782 } 783 var statementResult *StatementResult 784 if req.PartitionToken != nil { 785 statementResult, err = s.getPartitionResult(req.PartitionToken) 786 } else { 787 statementResult, err = s.getStatementResult(req.Sql) 788 } 789 if err != nil { 790 return err 791 } 792 s.mu.Lock() 793 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 794 s.mu.Unlock() 795 switch statementResult.Type { 796 case StatementResultError: 797 return statementResult.Err 798 case StatementResultResultSet: 799 parts, err := statementResult.toPartialResultSets(req.ResumeToken) 800 if err != nil { 801 return err 802 } 803 var nextPartialResultSetError *PartialResultSetExecutionTime 804 s.mu.Lock() 805 pErrors := s.partialResultSetErrors[req.Sql] 806 if len(pErrors) > 0 { 807 nextPartialResultSetError = pErrors[0] 808 s.partialResultSetErrors[req.Sql] = pErrors[1:] 809 } 810 s.mu.Unlock() 811 for _, part := range parts { 812 if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) { 813 if nextPartialResultSetError.ExecutionTime > 0 { 814 <-time.After(nextPartialResultSetError.ExecutionTime) 815 } 816 if nextPartialResultSetError.Err != nil { 817 return nextPartialResultSetError.Err 818 } 819 } 820 if err := stream.Send(part); err != nil { 821 return err 822 } 823 } 824 return nil 825 case StatementResultUpdateCount: 826 part := statementResult.updateCountToPartialResultSet(!isPartitionedDml) 827 if err := stream.Send(part); err != nil { 828 return err 829 } 830 return nil 831 } 832 return gstatus.Error(codes.Internal, "Unknown result type") 833} 834 835func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { 836 if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil { 837 return nil, err 838 } 839 if req.Session == "" { 840 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 841 } 842 session, err := s.findSession(req.Session) 843 if err != nil { 844 return nil, err 845 } 846 s.updateSessionLastUseTime(session.Name) 847 var id []byte 848 if id = s.getTransactionID(session, req.Transaction); id != nil { 849 _, err = s.getTransactionByID(id) 850 if err != nil { 851 return nil, err 852 } 853 } 854 s.mu.Lock() 855 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 856 s.mu.Unlock() 857 resp := &spannerpb.ExecuteBatchDmlResponse{} 858 resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements)) 859 resp.Status = &status.Status{Code: int32(codes.OK)} 860 for idx, batchStatement := range req.Statements { 861 statementResult, err := s.getStatementResult(batchStatement.Sql) 862 if err != nil { 863 return nil, err 864 } 865 switch statementResult.Type { 866 case StatementResultError: 867 resp.Status = &status.Status{Code: int32(gstatus.Code(statementResult.Err)), Message: statementResult.Err.Error()} 868 resp.ResultSets = resp.ResultSets[:idx] 869 return resp, nil 870 case StatementResultResultSet: 871 return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql)) 872 case StatementResultUpdateCount: 873 resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml) 874 } 875 } 876 return resp, nil 877} 878 879func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) { 880 s.mu.Lock() 881 if s.stopped { 882 s.mu.Unlock() 883 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 884 } 885 s.receivedRequests <- req 886 s.mu.Unlock() 887 return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") 888} 889 890func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { 891 s.mu.Lock() 892 if s.stopped { 893 s.mu.Unlock() 894 return gstatus.Error(codes.Unavailable, "server has been stopped") 895 } 896 s.receivedRequests <- req 897 s.mu.Unlock() 898 return gstatus.Error(codes.Unimplemented, "Method not yet implemented") 899} 900 901func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { 902 if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil { 903 return nil, err 904 } 905 if req.Session == "" { 906 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 907 } 908 session, err := s.findSession(req.Session) 909 if err != nil { 910 return nil, err 911 } 912 s.updateSessionLastUseTime(session.Name) 913 tx := s.beginTransaction(session, req.Options) 914 return tx, nil 915} 916 917func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) { 918 if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil { 919 return nil, err 920 } 921 if req.Session == "" { 922 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 923 } 924 session, err := s.findSession(req.Session) 925 if err != nil { 926 return nil, err 927 } 928 s.updateSessionLastUseTime(session.Name) 929 var tx *spannerpb.Transaction 930 if req.GetSingleUseTransaction() != nil { 931 tx = s.beginTransaction(session, req.GetSingleUseTransaction()) 932 } else if req.GetTransactionId() != nil { 933 tx, err = s.getTransactionByID(req.GetTransactionId()) 934 if err != nil { 935 return nil, err 936 } 937 } else { 938 return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request") 939 } 940 s.removeTransaction(tx) 941 return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil 942} 943 944func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { 945 s.mu.Lock() 946 if s.stopped { 947 s.mu.Unlock() 948 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 949 } 950 s.receivedRequests <- req 951 s.mu.Unlock() 952 if req.Session == "" { 953 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 954 } 955 session, err := s.findSession(req.Session) 956 if err != nil { 957 return nil, err 958 } 959 s.updateSessionLastUseTime(session.Name) 960 tx, err := s.getTransactionByID(req.TransactionId) 961 if err != nil { 962 return nil, err 963 } 964 s.removeTransaction(tx) 965 return &emptypb.Empty{}, nil 966} 967 968func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { 969 s.mu.Lock() 970 if s.stopped { 971 s.mu.Unlock() 972 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 973 } 974 s.receivedRequests <- req 975 s.mu.Unlock() 976 if req.Session == "" { 977 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 978 } 979 session, err := s.findSession(req.Session) 980 if err != nil { 981 return nil, err 982 } 983 var id []byte 984 var tx *spannerpb.Transaction 985 s.updateSessionLastUseTime(session.Name) 986 if id = s.getTransactionID(session, req.Transaction); id != nil { 987 tx, err = s.getTransactionByID(id) 988 if err != nil { 989 return nil, err 990 } 991 } 992 var partitions []*spannerpb.Partition 993 for i := int64(0); i < req.PartitionOptions.MaxPartitions; i++ { 994 token := make([]byte, 10) 995 _, err := rand.Read(token) 996 if err != nil { 997 return nil, gstatus.Error(codes.Internal, "failed to generate random partition token") 998 } 999 partitions = append(partitions, &spannerpb.Partition{PartitionToken: token}) 1000 } 1001 return &spannerpb.PartitionResponse{ 1002 Partitions: partitions, 1003 Transaction: tx, 1004 }, nil 1005} 1006 1007func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { 1008 s.mu.Lock() 1009 if s.stopped { 1010 s.mu.Unlock() 1011 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 1012 } 1013 s.receivedRequests <- req 1014 s.mu.Unlock() 1015 return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") 1016} 1017