// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testutil

import (
	"bytes"
	"context"
	"fmt"
	"math/rand"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/golang/protobuf/ptypes"
	emptypb "github.com/golang/protobuf/ptypes/empty"
	structpb "github.com/golang/protobuf/ptypes/struct"
	"github.com/golang/protobuf/ptypes/timestamp"
	"google.golang.org/genproto/googleapis/rpc/errdetails"
	"google.golang.org/genproto/googleapis/rpc/status"
	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
	"google.golang.org/grpc/codes"
	gstatus "google.golang.org/grpc/status"
)

// StatementResultType indicates the type of result returned by a SQL
// statement.
type StatementResultType int

const (
	// StatementResultError indicates that the sql statement returns an error.
	StatementResultError StatementResultType = 0
	// StatementResultResultSet indicates that the sql statement returns a
	// result set.
	StatementResultResultSet StatementResultType = 1
	// StatementResultUpdateCount indicates that the sql statement returns an
	// update count.
	StatementResultUpdateCount StatementResultType = 2
	// MaxRowsPerPartialResultSet is the maximum number of rows returned in
	// each PartialResultSet. This number is deliberately set to a low value to
	// ensure that most queries return more than one PartialResultSet.
	MaxRowsPerPartialResultSet = 1
)

// The method names that can be used to register execution times and errors.
const (
	MethodBeginTransaction    string = "BEGIN_TRANSACTION"
	MethodCommitTransaction   string = "COMMIT_TRANSACTION"
	MethodBatchCreateSession  string = "BATCH_CREATE_SESSION"
	MethodCreateSession       string = "CREATE_SESSION"
	MethodDeleteSession       string = "DELETE_SESSION"
	MethodGetSession          string = "GET_SESSION"
	MethodExecuteSql          string = "EXECUTE_SQL"
	MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL"
	MethodExecuteBatchDml     string = "EXECUTE_BATCH_DML"
)

// StatementResult represents a mocked result on the test server. The result is
// either of: a ResultSet, an update count or an error.
type StatementResult struct {
	// Type selects which of the fields below is used as the result.
	Type StatementResultType
	// Err is the error returned when Type is StatementResultError.
	Err error
	// ResultSet is the result returned when Type is StatementResultResultSet.
	ResultSet *spannerpb.ResultSet
	// UpdateCount is the row count reported when Type is
	// StatementResultUpdateCount.
	UpdateCount int64
}

// PartialResultSetExecutionTime represents execution times and errors that
// should be used when a PartialResult at the specified resume token is to
// be returned.
type PartialResultSetExecutionTime struct {
	// ResumeToken identifies the PartialResultSet this entry applies to; it
	// is compared byte-for-byte with the resume token of each partial result.
	ResumeToken []byte
	// ExecutionTime, when positive, is waited for before the matching
	// PartialResultSet is handled.
	ExecutionTime time.Duration
	// Err, when non-nil, is returned instead of sending the matching
	// PartialResultSet.
	Err error
}

// Converts a ResultSet to a PartialResultSet. This method is used to convert
// a mocked result to a PartialResultSet when one of the streaming methods is
// called.
func (s *StatementResult) toPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) { var startIndex uint64 if len(resumeToken) > 0 { if startIndex, err = DecodeResumeToken(resumeToken); err != nil { return nil, err } } totalRows := uint64(len(s.ResultSet.Rows)) for { rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet)) rows := s.ResultSet.Rows[startIndex : startIndex+rowCount] values := make([]*structpb.Value, len(rows)*len(s.ResultSet.Metadata.RowType.Fields)) var idx int for _, row := range rows { for colIdx := range s.ResultSet.Metadata.RowType.Fields { values[idx] = row.Values[colIdx] idx++ } } result = append(result, &spannerpb.PartialResultSet{ Metadata: s.ResultSet.Metadata, Values: values, ResumeToken: EncodeResumeToken(startIndex + rowCount), }) startIndex += rowCount if startIndex == totalRows { break } } return result, nil } func min(x, y uint64) uint64 { if x > y { return y } return x } func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet { return &spannerpb.PartialResultSet{ Stats: s.convertUpdateCountToResultSet(exact).Stats, } } // Converts an update count to a ResultSet, as DML statements also return the // update count as the statistics of a ResultSet. func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet { if exact { return &spannerpb.ResultSet{ Stats: &spannerpb.ResultSetStats{ RowCount: &spannerpb.ResultSetStats_RowCountExact{ RowCountExact: s.UpdateCount, }, }, } } return &spannerpb.ResultSet{ Stats: &spannerpb.ResultSetStats{ RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{ RowCountLowerBound: s.UpdateCount, }, }, } } // SimulatedExecutionTime represents the time the execution of a method // should take, and any errors that should be returned by the method. type SimulatedExecutionTime struct { MinimumExecutionTime time.Duration RandomExecutionTime time.Duration Errors []error // Keep error after execution. 
The error will continue to be returned until // it is cleared. KeepError bool } // InMemSpannerServer contains the SpannerServer interface plus a couple // of specific methods for adding mocked results and resetting the server. type InMemSpannerServer interface { spannerpb.SpannerServer // Stops this server. Stop() // Resets the in-mem server to its default state, deleting all sessions and // transactions that have been created on the server. Mocked results are // not deleted. Reset() // Sets an error that will be returned by the next server call. The server // call will also automatically clear the error. SetError(err error) // Puts a mocked result on the server for a specific sql statement. The // server does not parse the SQL string in any way, it is merely used as // a key to the mocked result. The result will be used for all methods that // expect a SQL statement, including (batch) DML methods. PutStatementResult(sql string, result *StatementResult) error // Adds a PartialResultSetExecutionTime to the server that should be returned // for the specified SQL string. AddPartialResultSetError(sql string, err PartialResultSetExecutionTime) // Removes a mocked result on the server for a specific sql statement. RemoveStatementResult(sql string) // Aborts the specified transaction . This method can be used to test // transaction retry logic. AbortTransaction(id []byte) // Puts a simulated execution time for one of the Spanner methods. PutExecutionTime(method string, executionTime SimulatedExecutionTime) // Freeze stalls all requests. Freeze() // Unfreeze restores processing requests. Unfreeze() TotalSessionsCreated() uint TotalSessionsDeleted() uint SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) ReceivedRequests() chan interface{} DumpSessions() map[string]bool ClearPings() DumpPings() []string } type inMemSpannerServer struct { // Embed for forward compatibility. 
// Tests will keep working if more methods are added // in the future. spannerpb.SpannerServer mu sync.Mutex // Set to true when this server been stopped. This is the end state of a // server, a stopped server cannot be restarted. stopped bool // If set, all calls return this error. err error // The mock server creates session IDs using this counter. sessionCounter uint64 // The sessions that have been created on this mock server. sessions map[string]*spannerpb.Session // Last use times per session. sessionLastUseTime map[string]time.Time // The mock server creates transaction IDs per session using these // counters. transactionCounters map[string]*uint64 // The transactions that have been created on this mock server. transactions map[string]*spannerpb.Transaction // The transactions that have been (manually) aborted on the server. abortedTransactions map[string]bool // The transactions that are marked as PartitionedDMLTransaction partitionedDmlTransactions map[string]bool // The mocked results for this server. statementResults map[string]*StatementResult // The simulated execution times per method. executionTimes map[string]*SimulatedExecutionTime // The simulated errors for partial result sets partialResultSetErrors map[string][]*PartialResultSetExecutionTime totalSessionsCreated uint totalSessionsDeleted uint // The maximum number of sessions that will be created per batch request. maxSessionsReturnedByServerPerBatchRequest int32 maxSessionsReturnedByServerInTotal int32 receivedRequests chan interface{} // Session ping history. pings []string // Server will stall on any requests. freezed chan struct{} } // NewInMemSpannerServer creates a new in-mem test server. 
func NewInMemSpannerServer() InMemSpannerServer { res := &inMemSpannerServer{} res.initDefaults() res.statementResults = make(map[string]*StatementResult) res.executionTimes = make(map[string]*SimulatedExecutionTime) res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime) res.receivedRequests = make(chan interface{}, 1000000) // Produce a closed channel, so the default action of ready is to not block. res.Freeze() res.Unfreeze() return res } func (s *inMemSpannerServer) Stop() { s.mu.Lock() defer s.mu.Unlock() s.stopped = true close(s.receivedRequests) } // Resets the test server to its initial state, deleting all sessions and // transactions that have been created on the server. This method will not // remove mocked results. func (s *inMemSpannerServer) Reset() { s.mu.Lock() defer s.mu.Unlock() close(s.receivedRequests) s.receivedRequests = make(chan interface{}, 1000000) s.initDefaults() } func (s *inMemSpannerServer) SetError(err error) { s.mu.Lock() defer s.mu.Unlock() s.err = err } // Registers a mocked result for a SQL statement on the server. func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error { s.mu.Lock() defer s.mu.Unlock() s.statementResults[sql] = result return nil } func (s *inMemSpannerServer) RemoveStatementResult(sql string) { s.mu.Lock() defer s.mu.Unlock() delete(s.statementResults, sql) } func (s *inMemSpannerServer) AbortTransaction(id []byte) { s.mu.Lock() defer s.mu.Unlock() s.abortedTransactions[string(id)] = true } func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) { s.mu.Lock() defer s.mu.Unlock() s.executionTimes[method] = &executionTime } func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) { s.mu.Lock() defer s.mu.Unlock() s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError) } // Freeze stalls all requests. 
func (s *inMemSpannerServer) Freeze() { s.mu.Lock() defer s.mu.Unlock() s.freezed = make(chan struct{}) } // Unfreeze restores processing requests. func (s *inMemSpannerServer) Unfreeze() { s.mu.Lock() defer s.mu.Unlock() close(s.freezed) } // ready checks conditions before executing requests func (s *inMemSpannerServer) ready() { s.mu.Lock() freezed := s.freezed s.mu.Unlock() // check if server should be freezed <-freezed } func (s *inMemSpannerServer) TotalSessionsCreated() uint { s.mu.Lock() defer s.mu.Unlock() return s.totalSessionsCreated } func (s *inMemSpannerServer) TotalSessionsDeleted() uint { s.mu.Lock() defer s.mu.Unlock() return s.totalSessionsDeleted } func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) { s.mu.Lock() defer s.mu.Unlock() s.maxSessionsReturnedByServerPerBatchRequest = sessionCount } func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) { s.mu.Lock() defer s.mu.Unlock() s.maxSessionsReturnedByServerInTotal = sessionCount } func (s *inMemSpannerServer) ReceivedRequests() chan interface{} { return s.receivedRequests } // ClearPings clears the ping history from the server. func (s *inMemSpannerServer) ClearPings() { s.mu.Lock() defer s.mu.Unlock() s.pings = nil } // DumpPings dumps the ping history. func (s *inMemSpannerServer) DumpPings() []string { s.mu.Lock() defer s.mu.Unlock() return append([]string(nil), s.pings...) } // DumpSessions dumps the internal session table. 
func (s *inMemSpannerServer) DumpSessions() map[string]bool { s.mu.Lock() defer s.mu.Unlock() st := map[string]bool{} for s := range s.sessions { st[s] = true } return st } func (s *inMemSpannerServer) initDefaults() { s.sessionCounter = 0 s.maxSessionsReturnedByServerPerBatchRequest = 100 s.sessions = make(map[string]*spannerpb.Session) s.sessionLastUseTime = make(map[string]time.Time) s.transactions = make(map[string]*spannerpb.Transaction) s.abortedTransactions = make(map[string]bool) s.partitionedDmlTransactions = make(map[string]bool) s.transactionCounters = make(map[string]*uint64) } func (s *inMemSpannerServer) generateSessionNameLocked(database string) string { s.sessionCounter++ return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter) } func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) { s.mu.Lock() defer s.mu.Unlock() session := s.sessions[name] if session == nil { return nil, newSessionNotFoundError(name) } return session, nil } // sessionResourceType is the type name of Spanner sessions. const sessionResourceType = "type.googleapis.com/google.spanner.v1.Session" func newSessionNotFoundError(name string) error { s := gstatus.Newf(codes.NotFound, "Session not found: Session with id %s not found", name) s, _ = s.WithDetails(&errdetails.ResourceInfo{ResourceType: sessionResourceType, ResourceName: name}) return s.Err() } func (s *inMemSpannerServer) updateSessionLastUseTime(session string) { s.mu.Lock() defer s.mu.Unlock() s.sessionLastUseTime[session] = time.Now() } func getCurrentTimestamp() *timestamp.Timestamp { t := time.Now() return ×tamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())} } // Gets the transaction id from the transaction selector. If the selector // specifies that a new transaction should be started, this method will start // a new transaction and return the id of that transaction. 
func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte { var res []byte if txSelector.GetBegin() != nil { // Start a new transaction. res = s.beginTransaction(session, txSelector.GetBegin()).Id } else if txSelector.GetId() != nil { res = txSelector.GetId() } return res } func (s *inMemSpannerServer) generateTransactionName(session string) string { s.mu.Lock() defer s.mu.Unlock() counter, ok := s.transactionCounters[session] if !ok { counter = new(uint64) s.transactionCounters[session] = counter } *counter++ return fmt.Sprintf("%s/transactions/%d", session, *counter) } func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction { id := s.generateTransactionName(session.Name) res := &spannerpb.Transaction{ Id: []byte(id), ReadTimestamp: getCurrentTimestamp(), } s.mu.Lock() s.transactions[id] = res s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil s.mu.Unlock() return res } func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { s.mu.Lock() defer s.mu.Unlock() tx, ok := s.transactions[string(id)] if !ok { return nil, gstatus.Error(codes.NotFound, "Transaction not found") } aborted, ok := s.abortedTransactions[string(id)] if ok && aborted { return nil, newAbortedErrorWithMinimalRetryDelay() } return tx, nil } func newAbortedErrorWithMinimalRetryDelay() error { st := gstatus.New(codes.Aborted, "Transaction has been aborted") retry := &errdetails.RetryInfo{ RetryDelay: ptypes.DurationProto(time.Nanosecond), } st, _ = st.WithDetails(retry) return st.Err() } func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) { s.mu.Lock() defer s.mu.Unlock() delete(s.transactions, string(tx.Id)) delete(s.partitionedDmlTransactions, string(tx.Id)) } func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) { s.mu.Lock() defer s.mu.Unlock() 
result, ok := s.statementResults[sql] if !ok { return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql)) } return result, nil } func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { s.mu.Lock() if s.stopped { s.mu.Unlock() return gstatus.Error(codes.Unavailable, "server has been stopped") } s.receivedRequests <- req s.mu.Unlock() s.ready() s.mu.Lock() if s.err != nil { err := s.err s.err = nil s.mu.Unlock() return err } executionTime, ok := s.executionTimes[method] s.mu.Unlock() if ok { var randTime int64 if executionTime.RandomExecutionTime > 0 { randTime = rand.Int63n(int64(executionTime.RandomExecutionTime)) } totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) <-time.After(totalExecutionTime) s.mu.Lock() if executionTime.Errors != nil && len(executionTime.Errors) > 0 { err := executionTime.Errors[0] if !executionTime.KeepError { executionTime.Errors = executionTime.Errors[1:] } s.mu.Unlock() return err } s.mu.Unlock() } return nil } func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { return nil, err } if req.Database == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing database") } s.mu.Lock() defer s.mu.Unlock() if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal { return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") } sessionName := s.generateSessionNameLocked(req.Database) ts := getCurrentTimestamp() session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} s.totalSessionsCreated++ s.sessions[sessionName] = session return session, nil } func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) 
(*spannerpb.BatchCreateSessionsResponse, error) { if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil { return nil, err } if req.Database == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing database") } if req.SessionCount <= 0 { return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0") } sessionsToCreate := req.SessionCount s.mu.Lock() defer s.mu.Unlock() if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal { return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") } if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest { sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest } if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal { sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions)) } sessions := make([]*spannerpb.Session, sessionsToCreate) for i := int32(0); i < sessionsToCreate; i++ { sessionName := s.generateSessionNameLocked(req.Database) ts := getCurrentTimestamp() sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} s.totalSessionsCreated++ s.sessions[sessionName] = sessions[i] } return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil } func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { return nil, err } s.mu.Lock() s.pings = append(s.pings, req.Name) s.mu.Unlock() if req.Name == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } session, err := s.findSession(req.Name) if err != nil { return nil, err } return session, nil } func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, 
error) { s.mu.Lock() if s.stopped { s.mu.Unlock() return nil, gstatus.Error(codes.Unavailable, "server has been stopped") } s.receivedRequests <- req s.mu.Unlock() if req.Database == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing database") } expectedSessionName := req.Database + "/sessions/" var sessions []*spannerpb.Session s.mu.Lock() for _, session := range s.sessions { if strings.Index(session.Name, expectedSessionName) == 0 { sessions = append(sessions, session) } } s.mu.Unlock() sort.Slice(sessions[:], func(i, j int) bool { return sessions[i].Name < sessions[j].Name }) res := &spannerpb.ListSessionsResponse{Sessions: sessions} return res, nil } func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { return nil, err } if req.Name == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } if _, err := s.findSession(req.Name); err != nil { return nil, err } s.mu.Lock() defer s.mu.Unlock() s.totalSessionsDeleted++ delete(s.sessions, req.Name) return &emptypb.Empty{}, nil } func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil { return nil, err } if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } session, err := s.findSession(req.Session) if err != nil { return nil, err } var id []byte s.updateSessionLastUseTime(session.Name) if id = s.getTransactionID(session, req.Transaction); id != nil { _, err = s.getTransactionByID(id) if err != nil { return nil, err } } statementResult, err := s.getStatementResult(req.Sql) if err != nil { return nil, err } s.mu.Lock() isPartitionedDml := s.partitionedDmlTransactions[string(id)] s.mu.Unlock() switch statementResult.Type { case StatementResultError: return 
nil, statementResult.Err case StatementResultResultSet: return statementResult.ResultSet, nil case StatementResultUpdateCount: return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil } return nil, gstatus.Error(codes.Internal, "Unknown result type") } func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { return err } if req.Session == "" { return gstatus.Error(codes.InvalidArgument, "Missing session name") } session, err := s.findSession(req.Session) if err != nil { return err } s.updateSessionLastUseTime(session.Name) var id []byte if id = s.getTransactionID(session, req.Transaction); id != nil { _, err = s.getTransactionByID(id) if err != nil { return err } } statementResult, err := s.getStatementResult(req.Sql) if err != nil { return err } s.mu.Lock() isPartitionedDml := s.partitionedDmlTransactions[string(id)] s.mu.Unlock() switch statementResult.Type { case StatementResultError: return statementResult.Err case StatementResultResultSet: parts, err := statementResult.toPartialResultSets(req.ResumeToken) if err != nil { return err } var nextPartialResultSetError *PartialResultSetExecutionTime s.mu.Lock() pErrors := s.partialResultSetErrors[req.Sql] if len(pErrors) > 0 { nextPartialResultSetError = pErrors[0] s.partialResultSetErrors[req.Sql] = pErrors[1:] } s.mu.Unlock() for _, part := range parts { if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) { if nextPartialResultSetError.ExecutionTime > 0 { <-time.After(nextPartialResultSetError.ExecutionTime) } if nextPartialResultSetError.Err != nil { return nextPartialResultSetError.Err } } if err := stream.Send(part); err != nil { return err } } return nil case StatementResultUpdateCount: part := statementResult.updateCountToPartialResultSet(!isPartitionedDml) if err := 
stream.Send(part); err != nil { return err } return nil } return gstatus.Error(codes.Internal, "Unknown result type") } func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil { return nil, err } if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } session, err := s.findSession(req.Session) if err != nil { return nil, err } s.updateSessionLastUseTime(session.Name) var id []byte if id = s.getTransactionID(session, req.Transaction); id != nil { _, err = s.getTransactionByID(id) if err != nil { return nil, err } } s.mu.Lock() isPartitionedDml := s.partitionedDmlTransactions[string(id)] s.mu.Unlock() resp := &spannerpb.ExecuteBatchDmlResponse{} resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements)) for idx, batchStatement := range req.Statements { statementResult, err := s.getStatementResult(batchStatement.Sql) if err != nil { return nil, err } switch statementResult.Type { case StatementResultError: resp.Status = &status.Status{Code: int32(codes.Unknown)} case StatementResultResultSet: return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql)) case StatementResultUpdateCount: resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml) resp.Status = &status.Status{Code: int32(codes.OK)} } } return resp, nil } func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) { s.mu.Lock() if s.stopped { s.mu.Unlock() return nil, gstatus.Error(codes.Unavailable, "server has been stopped") } s.receivedRequests <- req s.mu.Unlock() return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") } func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) 
error { s.mu.Lock() if s.stopped { s.mu.Unlock() return gstatus.Error(codes.Unavailable, "server has been stopped") } s.receivedRequests <- req s.mu.Unlock() return gstatus.Error(codes.Unimplemented, "Method not yet implemented") } func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil { return nil, err } if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } session, err := s.findSession(req.Session) if err != nil { return nil, err } s.updateSessionLastUseTime(session.Name) tx := s.beginTransaction(session, req.Options) return tx, nil } func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) { if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil { return nil, err } if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } session, err := s.findSession(req.Session) if err != nil { return nil, err } s.updateSessionLastUseTime(session.Name) var tx *spannerpb.Transaction if req.GetSingleUseTransaction() != nil { tx = s.beginTransaction(session, req.GetSingleUseTransaction()) } else if req.GetTransactionId() != nil { tx, err = s.getTransactionByID(req.GetTransactionId()) if err != nil { return nil, err } } else { return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request") } s.removeTransaction(tx) return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil } func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { s.mu.Lock() if s.stopped { s.mu.Unlock() return nil, gstatus.Error(codes.Unavailable, "server has been stopped") } s.receivedRequests <- req s.mu.Unlock() if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, 
"Missing session name") } session, err := s.findSession(req.Session) if err != nil { return nil, err } s.updateSessionLastUseTime(session.Name) tx, err := s.getTransactionByID(req.TransactionId) if err != nil { return nil, err } s.removeTransaction(tx) return &emptypb.Empty{}, nil } func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { s.mu.Lock() if s.stopped { s.mu.Unlock() return nil, gstatus.Error(codes.Unavailable, "server has been stopped") } s.receivedRequests <- req s.mu.Unlock() if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } session, err := s.findSession(req.Session) if err != nil { return nil, err } var id []byte var tx *spannerpb.Transaction s.updateSessionLastUseTime(session.Name) if id = s.getTransactionID(session, req.Transaction); id != nil { tx, err = s.getTransactionByID(id) if err != nil { return nil, err } } var partitions []*spannerpb.Partition for i := int64(0); i < req.PartitionOptions.MaxPartitions; i++ { token := make([]byte, 10) _, err := rand.Read(token) if err != nil { return nil, gstatus.Error(codes.Internal, "failed to generate random partition token") } partitions = append(partitions, &spannerpb.Partition{PartitionToken: token}) } return &spannerpb.PartitionResponse{ Partitions: partitions, Transaction: tx, }, nil } func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { s.mu.Lock() if s.stopped { s.mu.Unlock() return nil, gstatus.Error(codes.Unavailable, "server has been stopped") } s.receivedRequests <- req s.mu.Unlock() return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") }