1// Copyright 2019 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package testutil 16 17import ( 18 "bytes" 19 "context" 20 "encoding/binary" 21 "fmt" 22 "math/rand" 23 "sort" 24 "strings" 25 "sync" 26 "time" 27 28 "github.com/golang/protobuf/ptypes" 29 emptypb "github.com/golang/protobuf/ptypes/empty" 30 structpb "github.com/golang/protobuf/ptypes/struct" 31 "github.com/golang/protobuf/ptypes/timestamp" 32 "google.golang.org/genproto/googleapis/rpc/errdetails" 33 "google.golang.org/genproto/googleapis/rpc/status" 34 spannerpb "google.golang.org/genproto/googleapis/spanner/v1" 35 "google.golang.org/grpc/codes" 36 gstatus "google.golang.org/grpc/status" 37) 38 39var ( 40 // KvMeta is the Metadata for mocked KV table. 41 KvMeta = spannerpb.ResultSetMetadata{ 42 RowType: &spannerpb.StructType{ 43 Fields: []*spannerpb.StructType_Field{ 44 { 45 Name: "Key", 46 Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, 47 }, 48 { 49 Name: "Value", 50 Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, 51 }, 52 }, 53 }, 54 } 55) 56 57// StatementResultType indicates the type of result returned by a SQL 58// statement. 59type StatementResultType int 60 61const ( 62 // StatementResultError indicates that the sql statement returns an error. 63 StatementResultError StatementResultType = 0 64 // StatementResultResultSet indicates that the sql statement returns a 65 // result set. 
66 StatementResultResultSet StatementResultType = 1 67 // StatementResultUpdateCount indicates that the sql statement returns an 68 // update count. 69 StatementResultUpdateCount StatementResultType = 2 70 // MaxRowsPerPartialResultSet is the maximum number of rows returned in 71 // each PartialResultSet. This number is deliberately set to a low value to 72 // ensure that most queries return more than one PartialResultSet. 73 MaxRowsPerPartialResultSet = 1 74) 75 76// The method names that can be used to register execution times and errors. 77const ( 78 MethodBeginTransaction string = "BEGIN_TRANSACTION" 79 MethodCommitTransaction string = "COMMIT_TRANSACTION" 80 MethodBatchCreateSession string = "BATCH_CREATE_SESSION" 81 MethodCreateSession string = "CREATE_SESSION" 82 MethodDeleteSession string = "DELETE_SESSION" 83 MethodGetSession string = "GET_SESSION" 84 MethodExecuteSql string = "EXECUTE_SQL" 85 MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL" 86 MethodExecuteBatchDml string = "EXECUTE_BATCH_DML" 87 MethodStreamingRead string = "EXECUTE_STREAMING_READ" 88) 89 90// StatementResult represents a mocked result on the test server. The result is 91// either of: a ResultSet, an update count or an error. 92type StatementResult struct { 93 Type StatementResultType 94 Err error 95 ResultSet *spannerpb.ResultSet 96 UpdateCount int64 97 ResumeTokens [][]byte 98} 99 100// PartialResultSetExecutionTime represents execution times and errors that 101// should be used when a PartialResult at the specified resume token is to 102// be returned. 103type PartialResultSetExecutionTime struct { 104 ResumeToken []byte 105 ExecutionTime time.Duration 106 Err error 107} 108 109// ToPartialResultSets converts a ResultSet to a PartialResultSet. This method 110// is used to convert a mocked result to a PartialResultSet when one of the 111// streaming methods are called. 
112func (s *StatementResult) ToPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) { 113 var startIndex uint64 114 if len(resumeToken) > 0 { 115 if startIndex, err = DecodeResumeToken(resumeToken); err != nil { 116 return nil, err 117 } 118 } 119 120 totalRows := uint64(len(s.ResultSet.Rows)) 121 if totalRows > 0 { 122 for { 123 rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet)) 124 rows := s.ResultSet.Rows[startIndex : startIndex+rowCount] 125 values := make([]*structpb.Value, 126 len(rows)*len(s.ResultSet.Metadata.RowType.Fields)) 127 var idx int 128 for _, row := range rows { 129 for colIdx := range s.ResultSet.Metadata.RowType.Fields { 130 values[idx] = row.Values[colIdx] 131 idx++ 132 } 133 } 134 var rt []byte 135 if len(s.ResumeTokens) == 0 { 136 rt = EncodeResumeToken(startIndex + rowCount) 137 } else { 138 rt = s.ResumeTokens[startIndex] 139 } 140 result = append(result, &spannerpb.PartialResultSet{ 141 Metadata: s.ResultSet.Metadata, 142 Values: values, 143 ResumeToken: rt, 144 }) 145 146 startIndex += rowCount 147 if startIndex == totalRows { 148 break 149 } 150 } 151 } else { 152 result = append(result, &spannerpb.PartialResultSet{ 153 Metadata: s.ResultSet.Metadata, 154 }) 155 } 156 return result, nil 157} 158 159func min(x, y uint64) uint64 { 160 if x > y { 161 return y 162 } 163 return x 164} 165 166func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet { 167 return &spannerpb.PartialResultSet{ 168 Stats: s.convertUpdateCountToResultSet(exact).Stats, 169 } 170} 171 172// Converts an update count to a ResultSet, as DML statements also return the 173// update count as the statistics of a ResultSet. 
174func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet { 175 if exact { 176 return &spannerpb.ResultSet{ 177 Stats: &spannerpb.ResultSetStats{ 178 RowCount: &spannerpb.ResultSetStats_RowCountExact{ 179 RowCountExact: s.UpdateCount, 180 }, 181 }, 182 } 183 } 184 return &spannerpb.ResultSet{ 185 Stats: &spannerpb.ResultSetStats{ 186 RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{ 187 RowCountLowerBound: s.UpdateCount, 188 }, 189 }, 190 } 191} 192 193// SimulatedExecutionTime represents the time the execution of a method 194// should take, and any errors that should be returned by the method. 195type SimulatedExecutionTime struct { 196 MinimumExecutionTime time.Duration 197 RandomExecutionTime time.Duration 198 Errors []error 199 // Keep error after execution. The error will continue to be returned until 200 // it is cleared. 201 KeepError bool 202} 203 204// InMemSpannerServer contains the SpannerServer interface plus a couple 205// of specific methods for adding mocked results and resetting the server. 206type InMemSpannerServer interface { 207 spannerpb.SpannerServer 208 209 // Stops this server. 210 Stop() 211 212 // Resets the in-mem server to its default state, deleting all sessions and 213 // transactions that have been created on the server. Mocked results are 214 // not deleted. 215 Reset() 216 217 // Sets an error that will be returned by the next server call. The server 218 // call will also automatically clear the error. 219 SetError(err error) 220 221 // Puts a mocked result on the server for a specific sql statement. The 222 // server does not parse the SQL string in any way, it is merely used as 223 // a key to the mocked result. The result will be used for all methods that 224 // expect a SQL statement, including (batch) DML methods. 225 PutStatementResult(sql string, result *StatementResult) error 226 227 // Puts a mocked result on the server for a specific partition token. 
The 228 // result will only be used for query requests that specify a partition 229 // token. 230 PutPartitionResult(partitionToken []byte, result *StatementResult) error 231 232 // Adds a PartialResultSetExecutionTime to the server that should be returned 233 // for the specified SQL string. 234 AddPartialResultSetError(sql string, err PartialResultSetExecutionTime) 235 236 // Removes a mocked result on the server for a specific sql statement. 237 RemoveStatementResult(sql string) 238 239 // Aborts the specified transaction . This method can be used to test 240 // transaction retry logic. 241 AbortTransaction(id []byte) 242 243 // Puts a simulated execution time for one of the Spanner methods. 244 PutExecutionTime(method string, executionTime SimulatedExecutionTime) 245 // Freeze stalls all requests. 246 Freeze() 247 // Unfreeze restores processing requests. 248 Unfreeze() 249 250 TotalSessionsCreated() uint 251 TotalSessionsDeleted() uint 252 SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) 253 SetMaxSessionsReturnedByServerInTotal(sessionCount int32) 254 255 ReceivedRequests() chan interface{} 256 DumpSessions() map[string]bool 257 ClearPings() 258 DumpPings() []string 259} 260 261type inMemSpannerServer struct { 262 // Embed for forward compatibility. 263 // Tests will keep working if more methods are added 264 // in the future. 265 spannerpb.SpannerServer 266 267 mu sync.Mutex 268 // Set to true when this server been stopped. This is the end state of a 269 // server, a stopped server cannot be restarted. 270 stopped bool 271 // If set, all calls return this error. 272 err error 273 // The mock server creates session IDs using this counter. 274 sessionCounter uint64 275 // The sessions that have been created on this mock server. 276 sessions map[string]*spannerpb.Session 277 // Last use times per session. 278 sessionLastUseTime map[string]time.Time 279 // The mock server creates transaction IDs per session using these 280 // counters. 
281 transactionCounters map[string]*uint64 282 // The transactions that have been created on this mock server. 283 transactions map[string]*spannerpb.Transaction 284 // The transactions that have been (manually) aborted on the server. 285 abortedTransactions map[string]bool 286 // The transactions that are marked as PartitionedDMLTransaction 287 partitionedDmlTransactions map[string]bool 288 // The mocked results for this server. 289 statementResults map[string]*StatementResult 290 partitionResults map[string]*StatementResult 291 // The simulated execution times per method. 292 executionTimes map[string]*SimulatedExecutionTime 293 // The simulated errors for partial result sets 294 partialResultSetErrors map[string][]*PartialResultSetExecutionTime 295 296 totalSessionsCreated uint 297 totalSessionsDeleted uint 298 // The maximum number of sessions that will be created per batch request. 299 maxSessionsReturnedByServerPerBatchRequest int32 300 maxSessionsReturnedByServerInTotal int32 301 receivedRequests chan interface{} 302 // Session ping history. 303 pings []string 304 305 // Server will stall on any requests. 306 freezed chan struct{} 307} 308 309// NewInMemSpannerServer creates a new in-mem test server. 310func NewInMemSpannerServer() InMemSpannerServer { 311 res := &inMemSpannerServer{} 312 res.initDefaults() 313 res.statementResults = make(map[string]*StatementResult) 314 res.partitionResults = make(map[string]*StatementResult) 315 res.executionTimes = make(map[string]*SimulatedExecutionTime) 316 res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime) 317 res.receivedRequests = make(chan interface{}, 1000000) 318 // Produce a closed channel, so the default action of ready is to not block. 
319 res.Freeze() 320 res.Unfreeze() 321 return res 322} 323 324func (s *inMemSpannerServer) Stop() { 325 s.mu.Lock() 326 defer s.mu.Unlock() 327 s.stopped = true 328 close(s.receivedRequests) 329} 330 331// Resets the test server to its initial state, deleting all sessions and 332// transactions that have been created on the server. This method will not 333// remove mocked results. 334func (s *inMemSpannerServer) Reset() { 335 s.mu.Lock() 336 defer s.mu.Unlock() 337 close(s.receivedRequests) 338 s.receivedRequests = make(chan interface{}, 1000000) 339 s.initDefaults() 340} 341 342func (s *inMemSpannerServer) SetError(err error) { 343 s.mu.Lock() 344 defer s.mu.Unlock() 345 s.err = err 346} 347 348// Registers a mocked result for a SQL statement on the server. 349func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error { 350 s.mu.Lock() 351 defer s.mu.Unlock() 352 s.statementResults[sql] = result 353 return nil 354} 355 356func (s *inMemSpannerServer) RemoveStatementResult(sql string) { 357 s.mu.Lock() 358 defer s.mu.Unlock() 359 delete(s.statementResults, sql) 360} 361 362// Registers a mocked result for a partition token on the server. 
363func (s *inMemSpannerServer) PutPartitionResult(partitionToken []byte, result *StatementResult) error { 364 tokenString := string(partitionToken) 365 s.mu.Lock() 366 defer s.mu.Unlock() 367 s.partitionResults[tokenString] = result 368 return nil 369} 370 371func (s *inMemSpannerServer) AbortTransaction(id []byte) { 372 s.mu.Lock() 373 defer s.mu.Unlock() 374 s.abortedTransactions[string(id)] = true 375} 376 377func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) { 378 s.mu.Lock() 379 defer s.mu.Unlock() 380 s.executionTimes[method] = &executionTime 381} 382 383func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) { 384 s.mu.Lock() 385 defer s.mu.Unlock() 386 s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError) 387} 388 389// Freeze stalls all requests. 390func (s *inMemSpannerServer) Freeze() { 391 s.mu.Lock() 392 defer s.mu.Unlock() 393 s.freezed = make(chan struct{}) 394} 395 396// Unfreeze restores processing requests. 
397func (s *inMemSpannerServer) Unfreeze() { 398 s.mu.Lock() 399 defer s.mu.Unlock() 400 close(s.freezed) 401} 402 403// ready checks conditions before executing requests 404func (s *inMemSpannerServer) ready() { 405 s.mu.Lock() 406 freezed := s.freezed 407 s.mu.Unlock() 408 // check if server should be freezed 409 <-freezed 410} 411 412func (s *inMemSpannerServer) TotalSessionsCreated() uint { 413 s.mu.Lock() 414 defer s.mu.Unlock() 415 return s.totalSessionsCreated 416} 417 418func (s *inMemSpannerServer) TotalSessionsDeleted() uint { 419 s.mu.Lock() 420 defer s.mu.Unlock() 421 return s.totalSessionsDeleted 422} 423 424func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) { 425 s.mu.Lock() 426 defer s.mu.Unlock() 427 s.maxSessionsReturnedByServerPerBatchRequest = sessionCount 428} 429 430func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) { 431 s.mu.Lock() 432 defer s.mu.Unlock() 433 s.maxSessionsReturnedByServerInTotal = sessionCount 434} 435 436func (s *inMemSpannerServer) ReceivedRequests() chan interface{} { 437 return s.receivedRequests 438} 439 440// ClearPings clears the ping history from the server. 441func (s *inMemSpannerServer) ClearPings() { 442 s.mu.Lock() 443 defer s.mu.Unlock() 444 s.pings = nil 445} 446 447// DumpPings dumps the ping history. 448func (s *inMemSpannerServer) DumpPings() []string { 449 s.mu.Lock() 450 defer s.mu.Unlock() 451 return append([]string(nil), s.pings...) 452} 453 454// DumpSessions dumps the internal session table. 
455func (s *inMemSpannerServer) DumpSessions() map[string]bool { 456 s.mu.Lock() 457 defer s.mu.Unlock() 458 st := map[string]bool{} 459 for s := range s.sessions { 460 st[s] = true 461 } 462 return st 463} 464 465func (s *inMemSpannerServer) initDefaults() { 466 s.sessionCounter = 0 467 s.maxSessionsReturnedByServerPerBatchRequest = 100 468 s.sessions = make(map[string]*spannerpb.Session) 469 s.sessionLastUseTime = make(map[string]time.Time) 470 s.transactions = make(map[string]*spannerpb.Transaction) 471 s.abortedTransactions = make(map[string]bool) 472 s.partitionedDmlTransactions = make(map[string]bool) 473 s.transactionCounters = make(map[string]*uint64) 474} 475 476func (s *inMemSpannerServer) generateSessionNameLocked(database string) string { 477 s.sessionCounter++ 478 return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter) 479} 480 481func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) { 482 s.mu.Lock() 483 defer s.mu.Unlock() 484 session := s.sessions[name] 485 if session == nil { 486 return nil, newSessionNotFoundError(name) 487 } 488 return session, nil 489} 490 491// sessionResourceType is the type name of Spanner sessions. 492const sessionResourceType = "type.googleapis.com/google.spanner.v1.Session" 493 494func newSessionNotFoundError(name string) error { 495 s := gstatus.Newf(codes.NotFound, "Session not found: Session with id %s not found", name) 496 s, _ = s.WithDetails(&errdetails.ResourceInfo{ResourceType: sessionResourceType, ResourceName: name}) 497 return s.Err() 498} 499 500func (s *inMemSpannerServer) updateSessionLastUseTime(session string) { 501 s.mu.Lock() 502 defer s.mu.Unlock() 503 s.sessionLastUseTime[session] = time.Now() 504} 505 506func getCurrentTimestamp() *timestamp.Timestamp { 507 t := time.Now() 508 return ×tamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())} 509} 510 511// Gets the transaction id from the transaction selector. 
If the selector 512// specifies that a new transaction should be started, this method will start 513// a new transaction and return the id of that transaction. 514func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte { 515 var res []byte 516 if txSelector.GetBegin() != nil { 517 // Start a new transaction. 518 res = s.beginTransaction(session, txSelector.GetBegin()).Id 519 } else if txSelector.GetId() != nil { 520 res = txSelector.GetId() 521 } 522 return res 523} 524 525func (s *inMemSpannerServer) generateTransactionName(session string) string { 526 s.mu.Lock() 527 defer s.mu.Unlock() 528 counter, ok := s.transactionCounters[session] 529 if !ok { 530 counter = new(uint64) 531 s.transactionCounters[session] = counter 532 } 533 *counter++ 534 return fmt.Sprintf("%s/transactions/%d", session, *counter) 535} 536 537func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction { 538 id := s.generateTransactionName(session.Name) 539 res := &spannerpb.Transaction{ 540 Id: []byte(id), 541 ReadTimestamp: getCurrentTimestamp(), 542 } 543 s.mu.Lock() 544 s.transactions[id] = res 545 s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil 546 s.mu.Unlock() 547 return res 548} 549 550func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { 551 s.mu.Lock() 552 defer s.mu.Unlock() 553 tx, ok := s.transactions[string(id)] 554 if !ok { 555 return nil, gstatus.Error(codes.NotFound, "Transaction not found") 556 } 557 aborted, ok := s.abortedTransactions[string(id)] 558 if ok && aborted { 559 return nil, newAbortedErrorWithMinimalRetryDelay() 560 } 561 return tx, nil 562} 563 564func newAbortedErrorWithMinimalRetryDelay() error { 565 st := gstatus.New(codes.Aborted, "Transaction has been aborted") 566 retry := &errdetails.RetryInfo{ 567 RetryDelay: ptypes.DurationProto(time.Nanosecond), 568 } 
569 st, _ = st.WithDetails(retry) 570 return st.Err() 571} 572 573func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) { 574 s.mu.Lock() 575 defer s.mu.Unlock() 576 delete(s.transactions, string(tx.Id)) 577 delete(s.partitionedDmlTransactions, string(tx.Id)) 578} 579 580func (s *inMemSpannerServer) getPartitionResult(partitionToken []byte) (*StatementResult, error) { 581 tokenString := string(partitionToken) 582 s.mu.Lock() 583 defer s.mu.Unlock() 584 result, ok := s.partitionResults[tokenString] 585 if !ok { 586 return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for partition token %v", tokenString)) 587 } 588 return result, nil 589} 590 591func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) { 592 s.mu.Lock() 593 defer s.mu.Unlock() 594 result, ok := s.statementResults[sql] 595 if !ok { 596 return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql)) 597 } 598 return result, nil 599} 600 601func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { 602 s.mu.Lock() 603 if s.stopped { 604 s.mu.Unlock() 605 return gstatus.Error(codes.Unavailable, "server has been stopped") 606 } 607 s.receivedRequests <- req 608 s.mu.Unlock() 609 s.ready() 610 s.mu.Lock() 611 if s.err != nil { 612 err := s.err 613 s.err = nil 614 s.mu.Unlock() 615 return err 616 } 617 executionTime, ok := s.executionTimes[method] 618 s.mu.Unlock() 619 if ok { 620 var randTime int64 621 if executionTime.RandomExecutionTime > 0 { 622 randTime = rand.Int63n(int64(executionTime.RandomExecutionTime)) 623 } 624 totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) 625 <-time.After(totalExecutionTime) 626 s.mu.Lock() 627 if executionTime.Errors != nil && len(executionTime.Errors) > 0 { 628 err := executionTime.Errors[0] 629 if !executionTime.KeepError { 630 executionTime.Errors = executionTime.Errors[1:] 631 } 632 s.mu.Unlock() 633 
return err 634 } 635 s.mu.Unlock() 636 } 637 return nil 638} 639 640func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { 641 if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { 642 return nil, err 643 } 644 if req.Database == "" { 645 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 646 } 647 s.mu.Lock() 648 defer s.mu.Unlock() 649 if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal { 650 return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") 651 } 652 sessionName := s.generateSessionNameLocked(req.Database) 653 ts := getCurrentTimestamp() 654 session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} 655 s.totalSessionsCreated++ 656 s.sessions[sessionName] = session 657 return session, nil 658} 659 660func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { 661 if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil { 662 return nil, err 663 } 664 if req.Database == "" { 665 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 666 } 667 if req.SessionCount <= 0 { 668 return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0") 669 } 670 sessionsToCreate := req.SessionCount 671 s.mu.Lock() 672 defer s.mu.Unlock() 673 if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal { 674 return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") 675 } 676 if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest { 677 sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest 678 } 679 if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > 
s.maxSessionsReturnedByServerInTotal { 680 sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions)) 681 } 682 sessions := make([]*spannerpb.Session, sessionsToCreate) 683 for i := int32(0); i < sessionsToCreate; i++ { 684 sessionName := s.generateSessionNameLocked(req.Database) 685 ts := getCurrentTimestamp() 686 sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} 687 s.totalSessionsCreated++ 688 s.sessions[sessionName] = sessions[i] 689 } 690 return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil 691} 692 693func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { 694 if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { 695 return nil, err 696 } 697 if req.Name == "" { 698 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 699 } 700 session, err := s.findSession(req.Name) 701 if err != nil { 702 return nil, err 703 } 704 return session, nil 705} 706 707func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) { 708 s.mu.Lock() 709 if s.stopped { 710 s.mu.Unlock() 711 return nil, gstatus.Error(codes.Unavailable, "server has been stopped") 712 } 713 s.receivedRequests <- req 714 s.mu.Unlock() 715 if req.Database == "" { 716 return nil, gstatus.Error(codes.InvalidArgument, "Missing database") 717 } 718 expectedSessionName := req.Database + "/sessions/" 719 var sessions []*spannerpb.Session 720 s.mu.Lock() 721 for _, session := range s.sessions { 722 if strings.Index(session.Name, expectedSessionName) == 0 { 723 sessions = append(sessions, session) 724 } 725 } 726 s.mu.Unlock() 727 sort.Slice(sessions[:], func(i, j int) bool { 728 return sessions[i].Name < sessions[j].Name 729 }) 730 res := &spannerpb.ListSessionsResponse{Sessions: sessions} 731 return res, nil 732} 733 734func (s 
*inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { 735 if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { 736 return nil, err 737 } 738 if req.Name == "" { 739 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 740 } 741 if _, err := s.findSession(req.Name); err != nil { 742 return nil, err 743 } 744 s.mu.Lock() 745 defer s.mu.Unlock() 746 s.totalSessionsDeleted++ 747 delete(s.sessions, req.Name) 748 return &emptypb.Empty{}, nil 749} 750 751func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { 752 if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil { 753 return nil, err 754 } 755 if req.Sql == "SELECT 1" { 756 s.mu.Lock() 757 s.pings = append(s.pings, req.Session) 758 s.mu.Unlock() 759 } 760 if req.Session == "" { 761 return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") 762 } 763 session, err := s.findSession(req.Session) 764 if err != nil { 765 return nil, err 766 } 767 var id []byte 768 s.updateSessionLastUseTime(session.Name) 769 if id = s.getTransactionID(session, req.Transaction); id != nil { 770 _, err = s.getTransactionByID(id) 771 if err != nil { 772 return nil, err 773 } 774 } 775 var statementResult *StatementResult 776 if req.PartitionToken != nil { 777 statementResult, err = s.getPartitionResult(req.PartitionToken) 778 } else { 779 statementResult, err = s.getStatementResult(req.Sql) 780 } 781 if err != nil { 782 return nil, err 783 } 784 s.mu.Lock() 785 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 786 s.mu.Unlock() 787 switch statementResult.Type { 788 case StatementResultError: 789 return nil, statementResult.Err 790 case StatementResultResultSet: 791 return statementResult.ResultSet, nil 792 case StatementResultUpdateCount: 793 return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil 794 } 795 
return nil, gstatus.Error(codes.Internal, "Unknown result type") 796} 797 798func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { 799 if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { 800 return err 801 } 802 return s.executeStreamingSQL(req, stream) 803} 804 805func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { 806 if req.Session == "" { 807 return gstatus.Error(codes.InvalidArgument, "Missing session name") 808 } 809 session, err := s.findSession(req.Session) 810 if err != nil { 811 return err 812 } 813 s.updateSessionLastUseTime(session.Name) 814 var id []byte 815 if id = s.getTransactionID(session, req.Transaction); id != nil { 816 _, err = s.getTransactionByID(id) 817 if err != nil { 818 return err 819 } 820 } 821 var statementResult *StatementResult 822 if req.PartitionToken != nil { 823 statementResult, err = s.getPartitionResult(req.PartitionToken) 824 } else { 825 statementResult, err = s.getStatementResult(req.Sql) 826 } 827 if err != nil { 828 return err 829 } 830 s.mu.Lock() 831 isPartitionedDml := s.partitionedDmlTransactions[string(id)] 832 s.mu.Unlock() 833 switch statementResult.Type { 834 case StatementResultError: 835 return statementResult.Err 836 case StatementResultResultSet: 837 parts, err := statementResult.ToPartialResultSets(req.ResumeToken) 838 if err != nil { 839 return err 840 } 841 var nextPartialResultSetError *PartialResultSetExecutionTime 842 s.mu.Lock() 843 pErrors := s.partialResultSetErrors[req.Sql] 844 if len(pErrors) > 0 { 845 nextPartialResultSetError = pErrors[0] 846 s.partialResultSetErrors[req.Sql] = pErrors[1:] 847 } 848 s.mu.Unlock() 849 for _, part := range parts { 850 if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) { 851 if 
nextPartialResultSetError.ExecutionTime > 0 {
					<-time.After(nextPartialResultSetError.ExecutionTime)
				}
				if nextPartialResultSetError.Err != nil {
					return nextPartialResultSetError.Err
				}
			}
			if err := stream.Send(part); err != nil {
				return err
			}
		}
		return nil
	case StatementResultUpdateCount:
		part := statementResult.updateCountToPartialResultSet(!isPartitionedDml)
		if err := stream.Send(part); err != nil {
			return err
		}
		return nil
	}
	return gstatus.Error(codes.Internal, "Unknown result type")
}

// ExecuteBatchDml looks up the mocked result of every statement in the batch,
// in request order, and collects the update counts into the response.
//
// If a statement is mocked to return an error, processing stops at that
// statement: the error is reported through the response Status (not through
// the returned error) and ResultSets is truncated to the statements that
// preceded it. A statement that is mocked to return a result set makes the
// whole call fail with InvalidArgument, as does a missing session name.
func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
	if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
		return nil, err
	}
	if req.Session == "" {
		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
	}
	session, err := s.findSession(req.Session)
	if err != nil {
		return nil, err
	}
	s.updateSessionLastUseTime(session.Name)
	// A nil id means no transaction could be resolved for the selector; only
	// a non-nil id is validated against the known transactions.
	var id []byte
	if id = s.getTransactionID(session, req.Transaction); id != nil {
		_, err = s.getTransactionByID(id)
		if err != nil {
			return nil, err
		}
	}
	s.mu.Lock()
	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
	s.mu.Unlock()
	resp := &spannerpb.ExecuteBatchDmlResponse{}
	resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements))
	resp.Status = &status.Status{Code: int32(codes.OK)}
	for idx, batchStatement := range req.Statements {
		statementResult, err := s.getStatementResult(batchStatement.Sql)
		if err != nil {
			return nil, err
		}
		switch statementResult.Type {
		case StatementResultError:
			// Report the failure in-band and drop the slots of this and all
			// later statements.
			resp.Status = &status.Status{Code: int32(gstatus.Code(statementResult.Err)), Message: statementResult.Err.Error()}
			resp.ResultSets = resp.ResultSets[:idx]
			return resp, nil
		case StatementResultResultSet:
			return nil, gstatus.Errorf(codes.InvalidArgument, "Not an update statement: %v", batchStatement.Sql)
		case StatementResultUpdateCount:
			resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
		}
	}
	return resp, nil
}

// Read records the request on the server's received-requests channel and then
// always fails with Unimplemented. If the server has been stopped it returns
// Unavailable without recording the request.
func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) {
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
	}
	s.receivedRequests <- req
	s.mu.Unlock()
	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
}

// StreamingRead serves a read request by rewriting it as the query
// "SELECT <columns> FROM <table>" and delegating to executeStreamingSQL, so
// the result is whatever is mocked for that SQL string. The session,
// transaction selector, partition token and resume token are carried over
// from the read request.
func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
	if err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil {
		return err
	}
	sqlReq := &spannerpb.ExecuteSqlRequest{
		Session:        req.Session,
		Transaction:    req.Transaction,
		PartitionToken: req.PartitionToken,
		ResumeToken:    req.ResumeToken,
		// KeySet is currently ignored.
		Sql: fmt.Sprintf(
			"SELECT %s FROM %s",
			strings.Join(req.Columns, ", "),
			req.Table,
		),
	}
	return s.executeStreamingSQL(sqlReq, stream)
}

// BeginTransaction starts a new transaction with the requested options on the
// session named in the request and returns it. It fails with InvalidArgument
// when the session name is empty, and with the error from findSession when
// the session cannot be found.
func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
	if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
		return nil, err
	}
	if req.Session == "" {
		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
	}
	session, err := s.findSession(req.Session)
	if err != nil {
		return nil, err
	}
	s.updateSessionLastUseTime(session.Name)
	tx := s.beginTransaction(session, req.Options)
	return tx, nil
}

// Commit resolves the transaction of the request, removes it from the server
// and returns a response stamped with the current time.
//
// A single-use transaction is begun on the fly; otherwise the transaction is
// looked up by the supplied id. A request that carries neither is rejected
// with InvalidArgument. When ReturnCommitStats is set, the response contains
// commit stats with a fixed mutation count of 1.
func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) {
	if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil {
		return nil, err
	}
	if req.Session == "" {
		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
	}
	session, err := s.findSession(req.Session)
	if err != nil {
		return nil, err
	}
	s.updateSessionLastUseTime(session.Name)
	var tx *spannerpb.Transaction
	switch {
	case req.GetSingleUseTransaction() != nil:
		tx = s.beginTransaction(session, req.GetSingleUseTransaction())
	case req.GetTransactionId() != nil:
		tx, err = s.getTransactionByID(req.GetTransactionId())
		if err != nil {
			return nil, err
		}
	default:
		return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
	}
	s.removeTransaction(tx)
	resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
	if req.ReturnCommitStats {
		resp.CommitStats = &spannerpb.CommitResponse_CommitStats{
			MutationCount: 1,
		}
	}
	return resp, nil
}

// Rollback records the request, looks up the transaction identified by
// req.TransactionId and removes it from the server. It returns Unavailable
// when the server has been stopped, InvalidArgument when the session name is
// empty, and otherwise any error from the session or transaction lookup.
func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
	}
	s.receivedRequests <- req
	s.mu.Unlock()
	if req.Session == "" {
		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
	}
	session, err := s.findSession(req.Session)
	if err != nil {
		return nil, err
	}
	s.updateSessionLastUseTime(session.Name)
	tx, err := s.getTransactionByID(req.TransactionId)
	if err != nil {
		return nil, err
	}
	s.removeTransaction(tx)
	return &emptypb.Empty{}, nil
}

// PartitionQuery records the request and returns MaxPartitions partitions,
// each carrying a random 10-byte partition token. The query itself is not
// inspected. The response's Transaction is the transaction resolved from the
// request's selector, or nil when no transaction id could be resolved.
//
// A request without PartitionOptions is treated as asking for zero
// partitions.
func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
	}
	s.receivedRequests <- req
	s.mu.Unlock()
	if req.Session == "" {
		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
	}
	session, err := s.findSession(req.Session)
	if err != nil {
		return nil, err
	}
	var id []byte
	var tx *spannerpb.Transaction
	s.updateSessionLastUseTime(session.Name)
	if id = s.getTransactionID(session, req.Transaction); id != nil {
		tx, err = s.getTransactionByID(id)
		if err != nil {
			return nil, err
		}
	}
	// Use the generated getters: they are nil-safe, so a request that omits
	// PartitionOptions yields 0 instead of a nil-pointer panic.
	maxPartitions := req.GetPartitionOptions().GetMaxPartitions()
	var partitions []*spannerpb.Partition
	for i := int64(0); i < maxPartitions; i++ {
		token := make([]byte, 10)
		if _, err := rand.Read(token); err != nil {
			return nil, gstatus.Error(codes.Internal, "failed to generate random partition token")
		}
		partitions = append(partitions, &spannerpb.Partition{PartitionToken: token})
	}
	return &spannerpb.PartitionResponse{
		Partitions:  partitions,
		Transaction: tx,
	}, nil
}

// PartitionRead records the request on the server's received-requests channel
// and then always fails with Unimplemented. If the server has been stopped it
// returns Unavailable without recording the request.
func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) {
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
	}
	s.receivedRequests <- req
	s.mu.Unlock()
	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
}

// EncodeResumeToken returns a mock resume token encoding for a uint64
// integer. The token is always 16 bytes long: the uvarint encoding of t
// followed by zero padding.
func EncodeResumeToken(t uint64) []byte {
	rt := make([]byte, 16)
	binary.PutUvarint(rt, t)
	return rt
}

// DecodeResumeToken decodes a mock resume token into a uint64 integer. It
// returns an error when t does not start with a valid uvarint (for example
// when t is empty or the encoded value overflows 64 bits).
func DecodeResumeToken(t []byte) (uint64, error) {
	s, n := binary.Uvarint(t)
	if n <= 0 {
		return 0, fmt.Errorf("invalid resume token: %v", t)
	}
	return s, nil
}