1// Copyright 2019 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package testutil 16 17import ( 18 "bytes" 19 "context" 20 "fmt" 21 "math/rand" 22 "sort" 23 "strings" 24 "sync" 25 "time" 26 27 "github.com/golang/protobuf/ptypes" 28 emptypb "github.com/golang/protobuf/ptypes/empty" 29 structpb "github.com/golang/protobuf/ptypes/struct" 30 "github.com/golang/protobuf/ptypes/timestamp" 31 "google.golang.org/genproto/googleapis/rpc/errdetails" 32 "google.golang.org/genproto/googleapis/rpc/status" 33 spannerpb "google.golang.org/genproto/googleapis/spanner/v1" 34 "google.golang.org/grpc/codes" 35 gstatus "google.golang.org/grpc/status" 36) 37 38// StatementResultType indicates the type of result returned by a SQL 39// statement. 40type StatementResultType int 41 42const ( 43 // StatementResultError indicates that the sql statement returns an error. 44 StatementResultError StatementResultType = 0 45 // StatementResultResultSet indicates that the sql statement returns a 46 // result set. 47 StatementResultResultSet StatementResultType = 1 48 // StatementResultUpdateCount indicates that the sql statement returns an 49 // update count. 50 StatementResultUpdateCount StatementResultType = 2 51 // MaxRowsPerPartialResultSet is the maximum number of rows returned in 52 // each PartialResultSet. This number is deliberately set to a low value to 53 // ensure that most queries return more than one PartialResultSet. 54 MaxRowsPerPartialResultSet = 1 55) 56 57// The method names that can be used to register execution times and errors. 58const ( 59 MethodBeginTransaction string = "BEGIN_TRANSACTION" 60 MethodCommitTransaction string = "COMMIT_TRANSACTION" 61 MethodBatchCreateSession string = "BATCH_CREATE_SESSION" 62 MethodCreateSession string = "CREATE_SESSION" 63 MethodDeleteSession string = "DELETE_SESSION" 64 MethodGetSession string = "GET_SESSION" 65 MethodExecuteSql string = "EXECUTE_SQL" 66 MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL" 67 MethodExecuteBatchDml string = "EXECUTE_BATCH_DML" 68) 69 70// StatementResult represents a mocked result on the test server. The result is 71// either of: a ResultSet, an update count or an error. 72type StatementResult struct { 73 Type StatementResultType 74 Err error 75 ResultSet *spannerpb.ResultSet 76 UpdateCount int64 77} 78 79// PartialResultSetExecutionTime represents execution times and errors that 80// should be used when a PartialResult at the specified resume token is to 81// be returned. 82type PartialResultSetExecutionTime struct { 83 ResumeToken []byte 84 ExecutionTime time.Duration 85 Err error 86} 87 88// Converts a ResultSet to a PartialResultSet. This method is used to convert 89// a mocked result to a PartialResultSet when one of the streaming methods are 90// called. 91func (s *StatementResult) toPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) { 92 var startIndex uint64 93 if len(resumeToken) > 0 { 94 if startIndex, err = DecodeResumeToken(resumeToken); err != nil { 95 return nil, err 96 } 97 } 98 99 totalRows := uint64(len(s.ResultSet.Rows)) 100 for { 101 rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet)) 102 rows := s.ResultSet.Rows[startIndex : startIndex+rowCount] 103 values := make([]*structpb.Value, 104 len(rows)*len(s.ResultSet.Metadata.RowType.Fields)) 105 var idx int 106 for _, row := range rows { 107 for colIdx := range s.ResultSet.Metadata.RowType.Fields { 108 values[idx] = row.Values[colIdx] 109 idx++ 110 } 111 } 112 result = append(result, &spannerpb.PartialResultSet{ 113 Metadata: s.ResultSet.Metadata, 114 Values: values, 115 ResumeToken: EncodeResumeToken(startIndex + rowCount), 116 }) 117 startIndex += rowCount 118 if startIndex == totalRows { 119 break 120 } 121 } 122 return result, nil 123} 124 125func min(x, y uint64) uint64 { 126 if x > y { 127 return y 128 } 129 return x 130} 131 132func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet { 133 return &spannerpb.PartialResultSet{ 134 Stats: s.convertUpdateCountToResultSet(exact).Stats, 135 } 136} 137 138// Converts an update count to a ResultSet, as DML statements also return the 139// update count as the statistics of a ResultSet. 140func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet { 141 if exact { 142 return &spannerpb.ResultSet{ 143 Stats: &spannerpb.ResultSetStats{ 144 RowCount: &spannerpb.ResultSetStats_RowCountExact{ 145 RowCountExact: s.UpdateCount, 146 }, 147 }, 148 } 149 } 150 return &spannerpb.ResultSet{ 151 Stats: &spannerpb.ResultSetStats{ 152 RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{ 153 RowCountLowerBound: s.UpdateCount, 154 }, 155 }, 156 } 157} 158 159// SimulatedExecutionTime represents the time the execution of a method 160// should take, and any errors that should be returned by the method. 161type SimulatedExecutionTime struct { 162 MinimumExecutionTime time.Duration 163 RandomExecutionTime time.Duration 164 Errors []error 165 // Keep error after execution. The error will continue to be returned until 166 // it is cleared. 167 KeepError bool 168} 169 170// InMemSpannerServer contains the SpannerServer interface plus a couple 171// of specific methods for adding mocked results and resetting the server. 172type InMemSpannerServer interface { 173 spannerpb.SpannerServer 174 175 // Stops this server. 176 Stop() 177 178 // Resets the in-mem server to its default state, deleting all sessions and 179 // transactions that have been created on the server. Mocked results are 180 // not deleted. 181 Reset() 182 183 // Sets an error that will be returned by the next server call. The server 184 // call will also automatically clear the error. 185 SetError(err error) 186 187 // Puts a mocked result on the server for a specific sql statement. The 188 // server does not parse the SQL string in any way, it is merely used as 189 // a key to the mocked result. The result will be used for all methods that 190 // expect a SQL statement, including (batch) DML methods. 191 PutStatementResult(sql string, result *StatementResult) error 192 193 // Adds a PartialResultSetExecutionTime to the server that should be returned 194 // for the specified SQL string. 195 AddPartialResultSetError(sql string, err PartialResultSetExecutionTime) 196 197 // Removes a mocked result on the server for a specific sql statement. 198 RemoveStatementResult(sql string) 199 200 // Aborts the specified transaction . This method can be used to test 201 // transaction retry logic. 202 AbortTransaction(id []byte) 203 204 // Puts a simulated execution time for one of the Spanner methods. 205 PutExecutionTime(method string, executionTime SimulatedExecutionTime) 206 // Freeze stalls all requests. 207 Freeze() 208 // Unfreeze restores processing requests. 209 Unfreeze() 210 211 TotalSessionsCreated() uint 212 TotalSessionsDeleted() uint 213 SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) 214 SetMaxSessionsReturnedByServerInTotal(sessionCount int32) 215 216 ReceivedRequests() chan interface{} 217 DumpSessions() map[string]bool 218 ClearPings() 219 DumpPings() []string 220} 221 222type inMemSpannerServer struct { 223 // Embed for forward compatibility. 224 // Tests will keep working if more methods are added 225 // in the future. 226 spannerpb.SpannerServer 227 228 mu sync.Mutex 229 // Set to true when this server been stopped. This is the end state of a 230 // server, a stopped server cannot be restarted. 231 stopped bool 232 // If set, all calls return this error. 233 err error 234 // The mock server creates session IDs using this counter. 235 sessionCounter uint64 236 // The sessions that have been created on this mock server. 237 sessions map[string]*spannerpb.Session 238 // Last use times per session. 239 sessionLastUseTime map[string]time.Time 240 // The mock server creates transaction IDs per session using these 241 // counters. 242 transactionCounters map[string]*uint64 243 // The transactions that have been created on this mock server. 244 transactions map[string]*spannerpb.Transaction 245 // The transactions that have been (manually) aborted on the server. 246 abortedTransactions map[string]bool 247 // The transactions that are marked as PartitionedDMLTransaction 248 partitionedDmlTransactions map[string]bool 249 // The mocked results for this server. 250 statementResults map[string]*StatementResult 251 // The simulated execution times per method. 252 executionTimes map[string]*SimulatedExecutionTime 253 // The simulated errors for partial result sets 254 partialResultSetErrors map[string][]*PartialResultSetExecutionTime 255 256 totalSessionsCreated uint 257 totalSessionsDeleted uint 258 // The maximum number of sessions that will be created per batch request. 259 maxSessionsReturnedByServerPerBatchRequest int32 260 maxSessionsReturnedByServerInTotal int32 261 receivedRequests chan interface{} 262 // Session ping history. 263 pings []string 264 265 // Server will stall on any requests. 266 freezed chan struct{} 267} 268 269// NewInMemSpannerServer creates a new in-mem test server. 270func NewInMemSpannerServer() InMemSpannerServer { 271 res := &inMemSpannerServer{} 272 res.initDefaults() 273 res.statementResults = make(map[string]*StatementResult) 274 res.executionTimes = make(map[string]*SimulatedExecutionTime) 275 res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime) 276 res.receivedRequests = make(chan interface{}, 1000000) 277 // Produce a closed channel, so the default action of ready is to not block. 278 res.Freeze() 279 res.Unfreeze() 280 return res 281} 282 283func (s *inMemSpannerServer) Stop() { 284 s.mu.Lock() 285 defer s.mu.Unlock() 286 s.stopped = true 287 close(s.receivedRequests) 288} 289 290// Resets the test server to its initial state, deleting all sessions and 291// transactions that have been created on the server. This method will not 292// remove mocked results. 293func (s *inMemSpannerServer) Reset() { 294 s.mu.Lock() 295 defer s.mu.Unlock() 296 close(s.receivedRequests) 297 s.receivedRequests = make(chan interface{}, 1000000) 298 s.initDefaults() 299} 300 301func (s *inMemSpannerServer) SetError(err error) { 302 s.mu.Lock() 303 defer s.mu.Unlock() 304 s.err = err 305} 306 307// Registers a mocked result for a SQL statement on the server. 308func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error { 309 s.mu.Lock() 310 defer s.mu.Unlock() 311 s.statementResults[sql] = result 312 return nil 313} 314 315func (s *inMemSpannerServer) RemoveStatementResult(sql string) { 316 s.mu.Lock() 317 defer s.mu.Unlock() 318 delete(s.statementResults, sql) 319} 320 321func (s *inMemSpannerServer) AbortTransaction(id []byte) { 322 s.mu.Lock() 323 defer s.mu.Unlock() 324 s.abortedTransactions[string(id)] = true 325} 326 327func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) { 328 s.mu.Lock() 329 defer s.mu.Unlock() 330 s.executionTimes[method] = &executionTime 331} 332 333func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) { 334 s.mu.Lock() 335 defer s.mu.Unlock() 336 s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError) 337} 338 339// Freeze stalls all requests. 340func (s *inMemSpannerServer) Freeze() { 341 s.mu.Lock() 342 defer s.mu.Unlock() 343 s.freezed = make(chan struct{}) 344} 345 346// Unfreeze restores processing requests. 347func (s *inMemSpannerServer) Unfreeze() { 348 s.mu.Lock() 349 defer s.mu.Unlock() 350 close(s.freezed) 351} 352 353// ready checks conditions before executing requests 354func (s *inMemSpannerServer) ready() { 355 s.mu.Lock() 356 freezed := s.freezed 357 s.mu.Unlock() 358 // check if server should be freezed 359 <-freezed 360} 361 362func (s *inMemSpannerServer) TotalSessionsCreated() uint { 363 s.mu.Lock() 364 defer s.mu.Unlock() 365 return s.totalSessionsCreated 366} 367 368func (s *inMemSpannerServer) TotalSessionsDeleted() uint { 369 s.mu.Lock() 370 defer s.mu.Unlock() 371 return s.totalSessionsDeleted 372} 373 374func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) { 375 s.mu.Lock() 376 defer s.mu.Unlock() 377 s.maxSessionsReturnedByServerPerBatchRequest = sessionCount 378} 379 380func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) { 381 s.mu.Lock() 382 defer s.mu.Unlock() 383 s.maxSessionsReturnedByServerInTotal = sessionCount 384} 385 386func (s *inMemSpannerServer) ReceivedRequests() chan interface{} { 387 return s.receivedRequests 388} 389 390// ClearPings clears the ping history from the server. 391func (s *inMemSpannerServer) ClearPings() { 392 s.mu.Lock() 393 defer s.mu.Unlock() 394 s.pings = nil 395} 396 397// DumpPings dumps the ping history. 398func (s *inMemSpannerServer) DumpPings() []string { 399 s.mu.Lock() 400 defer s.mu.Unlock() 401 return append([]string(nil), s.pings...) 402} 403 404// DumpSessions dumps the internal session table. 405func (s *inMemSpannerServer) DumpSessions() map[string]bool { 406 s.mu.Lock() 407 defer s.mu.Unlock() 408 st := map[string]bool{} 409 for s := range s.sessions { 410 st[s] = true 411 } 412 return st 413} 414 415func (s *inMemSpannerServer) initDefaults() { 416 s.sessionCounter = 0 417 s.maxSessionsReturnedByServerPerBatchRequest = 100 418 s.sessions = make(map[string]*spannerpb.Session) 419 s.sessionLastUseTime = make(map[string]time.Time) 420 s.transactions = make(map[string]*spannerpb.Transaction) 421 s.abortedTransactions = make(map[string]bool) 422 s.partitionedDmlTransactions = make(map[string]bool) 423 s.transactionCounters = make(map[string]*uint64) 424} 425 426func (s *inMemSpannerServer) generateSessionNameLocked(database string) string { 427 s.sessionCounter++ 428 return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter) 429} 430 431func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) { 432 s.mu.Lock() 433 defer s.mu.Unlock() 434 session := s.sessions[name] 435 if session == nil { 436 return nil, newSessionNotFoundError(name) 437 } 438 return session, nil 439} 440 441// sessionResourceType is the type name of Spanner sessions. 442const sessionResourceType = "type.googleapis.com/google.spanner.v1.Session" 443 444func newSessionNotFoundError(name string) error { 445 s := gstatus.Newf(codes.NotFound, "Session not found: Session with id %s not found", name) 446 s, _ = s.WithDetails(&errdetails.ResourceInfo{ResourceType: sessionResourceType, ResourceName: name}) 447 return s.Err() 448} 449 450func (s *inMemSpannerServer) updateSessionLastUseTime(session string) { 451 s.mu.Lock() 452 defer s.mu.Unlock() 453 s.sessionLastUseTime[session] = time.Now() 454} 455 456func getCurrentTimestamp() *timestamp.Timestamp { 457 t := time.Now() 458 return ×tamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())} 459} 460 461// Gets the transaction id from the transaction selector. If the selector 462// specifies that a new transaction should be started, this method will start 463// a new transaction and return the id of that transaction. 464func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte { 465 var res []byte 466 if txSelector.GetBegin() != nil { 467 // Start a new transaction. 468 res = s.beginTransaction(session, txSelector.GetBegin()).Id 469 } else if txSelector.GetId() != nil { 470 res = txSelector.GetId() 471 } 472 return res 473} 474 475func (s *inMemSpannerServer) generateTransactionName(session string) string { 476 s.mu.Lock() 477 defer s.mu.Unlock() 478 counter, ok := s.transactionCounters[session] 479 if !ok { 480 counter = new(uint64) 481 s.transactionCounters[session] = counter 482 } 483 *counter++ 484 return fmt.Sprintf("%s/transactions/%d", session, *counter) 485} 486 487func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction { 488 id := s.generateTransactionName(session.Name) 489 res := &spannerpb.Transaction{ 490 Id: []byte(id), 491 ReadTimestamp: getCurrentTimestamp(), 492 } 493 s.mu.Lock() 494 s.transactions[id] = res 495 s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil 496 s.mu.Unlock() 497 return res 498} 499 500func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { 501 s.mu.Lock() 502 defer s.mu.Unlock() 503 tx, ok := s.transactions[string(id)] 504 if !ok { 505 return nil, gstatus.Error(codes.NotFound, "Transaction not found") 506 } 507 aborted, ok := s.abortedTransactions[string(id)] 508 if ok && aborted { 509 return nil, newAbortedErrorWithMinimalRetryDelay() 510 } 511 return tx, nil 512} 513 514func newAbortedErrorWithMinimalRetryDelay() error { 515 st := gstatus.New(codes.Aborted, "Transaction has been aborted") 516 retry := &errdetails.RetryInfo{ 517 RetryDelay: ptypes.DurationProto(time.Nanosecond), 518 } 519 st, _ = st.WithDetails(retry) 520 return st.Err() 521} 522 523func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) { 524 s.mu.Lock() 525 defer s.mu.Unlock() 526 delete(s.transactions, string(tx.Id)) 527 delete(s.partitionedDmlTransactions, string(tx.Id)) 528} 529 530func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) { 531 s.mu.Lock() 532 defer s.mu.Unlock() 533 result, ok := s.statementResults[sql] 534 if !ok { 535 return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql)) 536 } 537 return result, nil 538} 539 540func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { 541 s.mu.Lock() 542 if s.stopped { 543 s.mu.Unlock() 544 return gstatus.Error(codes.Unavailable, "server has been stopped") 545 } 546 s.receivedRequests <- req 547 s.mu.Unlock() 548 s.ready() 549 s.mu.Lock() 550 if s.err != nil { 551 err := s.err 552 s.err = nil 553 s.mu.Unlock() 554 return err 555 } 556 executionTime, ok := s.executionTimes[method] 557 s.mu.Unlock() 558 if ok { 559 var randTime int64 560 if executionTime.RandomExecutionTime > 0 { 561 randTime = rand.Int63n(int64(executionTime.RandomExecutionTime)) 562 } 563 totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) 564 <-time.After(totalExecutionTime) 565 s.mu.Lock() 566 if executionTime.Errors != nil && len(executionTime.Errors) > 0 { 567 err := executionTime.Errors[0] 568 if !executionTime.KeepError { 569 executionTime.Errors = executionTime.Errors[1:] 570 } 571 s.mu.Unlock() 572 return err 573 } 574 s.mu.Unlock() 575 } 576 return nil 577} 578 579func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { 580 if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { 581 return nil, err 582 } 583 if req.Database == "" { 584 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 585 } 586 s.mu.Lock() 587 defer s.mu.Unlock() 588 if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal { 589 return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") 590 } 591 sessionName := s.generateSessionNameLocked(req.Database) 592 ts := getCurrentTimestamp() 593 session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} 594 s.totalSessionsCreated++ 595 s.sessions[sessionName] = session 596 return session, nil 597} 598 599func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { 600 if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil { 601 return nil, err 602 } 603 if req.Database == "" { 604 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 605 } 606 if req.SessionCount <= 0 { 607 return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0") 608 } 609 sessionsToCreate := req.SessionCount 610 s.mu.Lock() 611 defer s.mu.Unlock() 612 if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal { 613 return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") 614 } 615 if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest { 616 sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest 617 } 618 if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal { 619 sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions)) 620 } 621 sessions := make([]*spannerpb.Session, sessionsToCreate) 622 for i := int32(0); i < sessionsToCreate; i++ { 623 sessionName := s.generateSessionNameLocked(req.Database) 624 ts := getCurrentTimestamp() 625 sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} 626 s.totalSessionsCreated++ 627 s.sessions[sessionName] = sessions[i] 628 } 629 return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil 630} 631 632func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { 633 if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { 634 return nil, err 635 } 636 s.mu.Lock() 637 s.pings = append(s.pings, req.Name) 638 s.mu.Unlock() 639 if req.Name == "" { 640 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 641 } 642 session, err := s.findSession(req.Name) 643 if err != nil { 644 return nil, err 645 } 646 return session, nil 647} 648 649func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) { 650 s.mu.Lock() 651 if s.stopped { 652 s.mu.Unlock() 653 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 654 } 655 s.receivedRequests <- req 656 s.mu.Unlock() 657 if req.Database == "" { 658 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 659 } 660 expectedSessionName := req.Database + "/sessions/" 661 var sessions []*spannerpb.Session 662 s.mu.Lock() 663 for _, session := range s.sessions { 664 if strings.Index(session.Name, expectedSessionName) == 0 { 665 sessions = append(sessions, session) 666 } 667 } 668 s.mu.Unlock() 669 sort.Slice(sessions[:], func(i, j int) bool { 670 return sessions[i].Name < sessions[j].Name 671 }) 672 res := &spannerpb.ListSessionsResponse{Sessions: sessions} 673 return res, nil 674} 675 676func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { 677 if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { 678 return nil, err 679 } 680 if req.Name == "" { 681 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 682 } 683 if _, err := s.findSession(req.Name); err != nil { 684 return nil, err 685 } 686 s.mu.Lock() 687 defer s.mu.Unlock() 688 s.totalSessionsDeleted++ 689 delete(s.sessions, req.Name) 690 return &emptypb.Empty{}, nil 691} 692 693func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { 694 if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil { 695 return nil, err 696 } 697 if req.Session == "" { 698 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 699 } 700 session, err := s.findSession(req.Session) 701 if err != nil { 702 return nil, err 703 } 704 var id []byte 705 s.updateSessionLastUseTime(session.Name) 706 if id = s.getTransactionID(session, req.Transaction); id != nil { 707 _, err = s.getTransactionByID(id) 708 if err != nil { 709 return nil, err 710 } 711 } 712 statementResult, err := s.getStatementResult(req.Sql) 713 if err != nil { 714 return nil, err 715 } 716 s.mu.Lock() 717 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 718 s.mu.Unlock() 719 switch statementResult.Type { 720 case StatementResultError: 721 return nil, statementResult.Err 722 case StatementResultResultSet: 723 return statementResult.ResultSet, nil 724 case StatementResultUpdateCount: 725 return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil 726 } 727 return nil, gstatus.Error(codes.Internal, "Unknown result type") 728} 729 730func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { 731 if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { 732 return err 733 } 734 if req.Session == "" { 735 return gstatus.Error(codes.InvalidArgument, "Missing session name") 736 } 737 session, err := s.findSession(req.Session) 738 if err != nil { 739 return err 740 } 741 s.updateSessionLastUseTime(session.Name) 742 var id []byte 743 if id = s.getTransactionID(session, req.Transaction); id != nil { 744 _, err = s.getTransactionByID(id) 745 if err != nil { 746 return err 747 } 748 } 749 statementResult, err := s.getStatementResult(req.Sql) 750 if err != nil { 751 return err 752 } 753 s.mu.Lock() 754 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 755 s.mu.Unlock() 756 switch statementResult.Type { 757 case StatementResultError: 758 return statementResult.Err 759 case StatementResultResultSet: 760 parts, err := statementResult.toPartialResultSets(req.ResumeToken) 761 if err != nil { 762 return err 763 } 764 var nextPartialResultSetError *PartialResultSetExecutionTime 765 s.mu.Lock() 766 pErrors := s.partialResultSetErrors[req.Sql] 767 if len(pErrors) > 0 { 768 nextPartialResultSetError = pErrors[0] 769 s.partialResultSetErrors[req.Sql] = pErrors[1:] 770 } 771 s.mu.Unlock() 772 for _, part := range parts { 773 if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) { 774 if nextPartialResultSetError.ExecutionTime > 0 { 775 <-time.After(nextPartialResultSetError.ExecutionTime) 776 } 777 if nextPartialResultSetError.Err != nil { 778 return nextPartialResultSetError.Err 779 } 780 } 781 if err := stream.Send(part); err != nil { 782 return err 783 } 784 } 785 return nil 786 case StatementResultUpdateCount: 787 part := statementResult.updateCountToPartialResultSet(!isPartitionedDml) 788 if err := stream.Send(part); err != nil { 789 return err 790 } 791 return nil 792 } 793 return gstatus.Error(codes.Internal, "Unknown result type") 794} 795 796func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { 797 if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil { 798 return nil, err 799 } 800 if req.Session == "" { 801 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 802 } 803 session, err := s.findSession(req.Session) 804 if err != nil { 805 return nil, err 806 } 807 s.updateSessionLastUseTime(session.Name) 808 var id []byte 809 if id = s.getTransactionID(session, req.Transaction); id != nil { 810 _, err = s.getTransactionByID(id) 811 if err != nil { 812 return nil, err 813 } 814 } 815 s.mu.Lock() 816 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 817 s.mu.Unlock() 818 resp := &spannerpb.ExecuteBatchDmlResponse{} 819 resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements)) 820 for idx, batchStatement := range req.Statements { 821 statementResult, err := s.getStatementResult(batchStatement.Sql) 822 if err != nil { 823 return nil, err 824 } 825 switch statementResult.Type { 826 case StatementResultError: 827 resp.Status = &status.Status{Code: int32(codes.Unknown)} 828 case StatementResultResultSet: 829 return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql)) 830 case StatementResultUpdateCount: 831 resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml) 832 resp.Status = &status.Status{Code: int32(codes.OK)} 833 } 834 } 835 return resp, nil 836} 837 838func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) { 839 s.mu.Lock() 840 if s.stopped { 841 s.mu.Unlock() 842 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 843 } 844 s.receivedRequests <- req 845 s.mu.Unlock() 846 return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") 847} 848 849func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { 850 s.mu.Lock() 851 if s.stopped { 852 s.mu.Unlock() 853 return gstatus.Error(codes.Unavailable, "server has been stopped") 854 } 855 s.receivedRequests <- req 856 s.mu.Unlock() 857 return gstatus.Error(codes.Unimplemented, "Method not yet implemented") 858} 859 860func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { 861 if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil { 862 return nil, err 863 } 864 if req.Session == "" { 865 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 866 } 867 session, err := s.findSession(req.Session) 868 if err != nil { 869 return nil, err 870 } 871 s.updateSessionLastUseTime(session.Name) 872 tx := s.beginTransaction(session, req.Options) 873 return tx, nil 874} 875 876func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) { 877 if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil { 878 return nil, err 879 } 880 if req.Session == "" { 881 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 882 } 883 session, err := s.findSession(req.Session) 884 if err != nil { 885 return nil, err 886 } 887 s.updateSessionLastUseTime(session.Name) 888 var tx *spannerpb.Transaction 889 if req.GetSingleUseTransaction() != nil { 890 tx = s.beginTransaction(session, req.GetSingleUseTransaction()) 891 } else if req.GetTransactionId() != nil { 892 tx, err = s.getTransactionByID(req.GetTransactionId()) 893 if err != nil { 894 return nil, err 895 } 896 } else { 897 return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request") 898 } 899 s.removeTransaction(tx) 900 return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil 901} 902 903func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { 904 s.mu.Lock() 905 if s.stopped { 906 s.mu.Unlock() 907 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 908 } 909 s.receivedRequests <- req 910 s.mu.Unlock() 911 if req.Session == "" { 912 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 913 } 914 session, err := s.findSession(req.Session) 915 if err != nil { 916 return nil, err 917 } 918 s.updateSessionLastUseTime(session.Name) 919 tx, err := s.getTransactionByID(req.TransactionId) 920 if err != nil { 921 return nil, err 922 } 923 s.removeTransaction(tx) 924 return &emptypb.Empty{}, nil 925} 926 927func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { 928 s.mu.Lock() 929 if s.stopped { 930 s.mu.Unlock() 931 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 932 } 933 s.receivedRequests <- req 934 s.mu.Unlock() 935 if req.Session == "" { 936 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 937 } 938 session, err := s.findSession(req.Session) 939 if err != nil { 940 return nil, err 941 } 942 var id []byte 943 var tx *spannerpb.Transaction 944 s.updateSessionLastUseTime(session.Name) 945 if id = s.getTransactionID(session, req.Transaction); id != nil { 946 tx, err = s.getTransactionByID(id) 947 if err != nil { 948 return nil, err 949 } 950 } 951 var partitions []*spannerpb.Partition 952 for i := int64(0); i < req.PartitionOptions.MaxPartitions; i++ { 953 token := make([]byte, 10) 954 _, err := rand.Read(token) 955 if err != nil { 956 return nil, gstatus.Error(codes.Internal, "failed to generate random partition token") 957 } 958 partitions = append(partitions, &spannerpb.Partition{PartitionToken: token}) 959 } 960 return &spannerpb.PartitionResponse{ 961 Partitions: partitions, 962 Transaction: tx, 963 }, nil 964} 965 966func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { 967 s.mu.Lock() 968 if s.stopped { 969 s.mu.Unlock() 970 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 971 } 972 s.receivedRequests <- req 973 s.mu.Unlock() 974 return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") 975} 976