1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package testutil
16
17import (
18	"bytes"
19	"context"
20	"fmt"
21	"math/rand"
22	"sort"
23	"strings"
24	"sync"
25	"time"
26
27	emptypb "github.com/golang/protobuf/ptypes/empty"
28	structpb "github.com/golang/protobuf/ptypes/struct"
29	"github.com/golang/protobuf/ptypes/timestamp"
30	"google.golang.org/genproto/googleapis/rpc/status"
31	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
32	"google.golang.org/grpc/codes"
33	gstatus "google.golang.org/grpc/status"
34)
35
// StatementResultType identifies which kind of mocked result a SQL statement
// produces on the in-memory test server.
type StatementResultType int

// The kinds of mocked results that can be registered for a SQL statement.
const (
	// StatementResultError indicates that the sql statement returns an error.
	StatementResultError StatementResultType = iota
	// StatementResultResultSet indicates that the sql statement returns a
	// result set.
	StatementResultResultSet
	// StatementResultUpdateCount indicates that the sql statement returns an
	// update count.
	StatementResultUpdateCount
)

// MaxRowsPerPartialResultSet is the maximum number of rows returned in each
// PartialResultSet. It is deliberately kept very low so that most queries are
// streamed as more than one PartialResultSet.
const MaxRowsPerPartialResultSet = 1
54
// Method names accepted by PutExecutionTime to register simulated execution
// times and errors for a specific RPC.
const (
	// Session management RPCs.
	MethodBatchCreateSession string = "BATCH_CREATE_SESSION"
	MethodCreateSession      string = "CREATE_SESSION"
	MethodDeleteSession      string = "DELETE_SESSION"
	MethodGetSession         string = "GET_SESSION"

	// Transaction RPCs.
	MethodBeginTransaction  string = "BEGIN_TRANSACTION"
	MethodCommitTransaction string = "COMMIT_TRANSACTION"

	// Query RPCs.
	MethodExecuteSql          string = "EXECUTE_SQL"
	MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL"
)
66
// StatementResult represents a mocked result on the test server. The result is
// either of: a ResultSet, an update count or an error. Type selects which of
// the remaining fields the server actually uses.
type StatementResult struct {
	// Type selects the kind of result returned for the statement.
	Type        StatementResultType
	// Err is returned to the caller when Type is StatementResultError.
	Err         error
	// ResultSet is returned (streamed in parts for streaming calls) when Type
	// is StatementResultResultSet.
	ResultSet   *spannerpb.ResultSet
	// UpdateCount is reported as the statement's row count in the result
	// statistics when Type is StatementResultUpdateCount.
	UpdateCount int64
}
75
// PartialResultSetExecutionTime represents execution times and errors that
// should be used when a PartialResult at the specified resume token is to
// be returned.
type PartialResultSetExecutionTime struct {
	// ResumeToken is compared (bytes.Equal) with the resume token of each
	// streamed PartialResultSet; the delay and error below apply just before
	// the matching part is sent.
	ResumeToken   []byte
	// ExecutionTime is how long the server waits before sending the matching
	// part. Zero or negative means no delay.
	ExecutionTime time.Duration
	// Err, if non-nil, ends the stream with this error instead of sending the
	// matching part.
	Err           error
}
84
85// Converts a ResultSet to a PartialResultSet. This method is used to convert
86// a mocked result to a PartialResultSet when one of the streaming methods are
87// called.
88func (s *StatementResult) toPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) {
89	var startIndex uint64
90	if len(resumeToken) > 0 {
91		if startIndex, err = DecodeResumeToken(resumeToken); err != nil {
92			return nil, err
93		}
94	}
95
96	totalRows := uint64(len(s.ResultSet.Rows))
97	for {
98		rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet))
99		rows := s.ResultSet.Rows[startIndex : startIndex+rowCount]
100		values := make([]*structpb.Value,
101			len(rows)*len(s.ResultSet.Metadata.RowType.Fields))
102		var idx int
103		for _, row := range rows {
104			for colIdx := range s.ResultSet.Metadata.RowType.Fields {
105				values[idx] = row.Values[colIdx]
106				idx++
107			}
108		}
109		result = append(result, &spannerpb.PartialResultSet{
110			Metadata:    s.ResultSet.Metadata,
111			Values:      values,
112			ResumeToken: EncodeResumeToken(startIndex + rowCount),
113		})
114		startIndex += rowCount
115		if startIndex == totalRows {
116			break
117		}
118	}
119	return result, nil
120}
121
// min returns the smaller of two uint64 values. Within this package it shadows
// the builtin min introduced in Go 1.21.
func min(x, y uint64) uint64 {
	if x < y {
		return x
	}
	return y
}
128
129func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet {
130	return &spannerpb.PartialResultSet{
131		Stats: s.convertUpdateCountToResultSet(exact).Stats,
132	}
133}
134
135// Converts an update count to a ResultSet, as DML statements also return the
136// update count as the statistics of a ResultSet.
137func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet {
138	if exact {
139		return &spannerpb.ResultSet{
140			Stats: &spannerpb.ResultSetStats{
141				RowCount: &spannerpb.ResultSetStats_RowCountExact{
142					RowCountExact: s.UpdateCount,
143				},
144			},
145		}
146	}
147	return &spannerpb.ResultSet{
148		Stats: &spannerpb.ResultSetStats{
149			RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{
150				RowCountLowerBound: s.UpdateCount,
151			},
152		},
153	}
154}
155
// SimulatedExecutionTime represents the time the execution of a method
// should take, and any errors that should be returned by the method.
type SimulatedExecutionTime struct {
	// MinimumExecutionTime is always waited before the method proceeds.
	MinimumExecutionTime time.Duration
	// RandomExecutionTime, if positive, adds a pseudo-random extra delay in
	// the range [0, RandomExecutionTime) to every call.
	RandomExecutionTime  time.Duration
	// Errors are returned one per call, in order, after the delay.
	Errors               []error
	// Keep error after execution. The error will continue to be returned until
	// it is cleared. When false, each returned error is removed from Errors.
	KeepError bool
}
166
// InMemSpannerServer contains the SpannerServer interface plus a couple
// of specific methods for adding mocked results and resetting the server.
type InMemSpannerServer interface {
	spannerpb.SpannerServer

	// Stops this server. A stopped server cannot be restarted; subsequent
	// calls are rejected.
	Stop()

	// Resets the in-mem server to its default state, deleting all sessions and
	// transactions that have been created on the server. Mocked results are
	// not deleted.
	Reset()

	// Sets an error that will be returned by the next server call. The server
	// call will also automatically clear the error.
	SetError(err error)

	// Puts a mocked result on the server for a specific sql statement. The
	// server does not parse the SQL string in any way, it is merely used as
	// a key to the mocked result. The result will be used for all methods that
	// expect a SQL statement, including (batch) DML methods.
	PutStatementResult(sql string, result *StatementResult) error

	// Adds a PartialResultSetExecutionTime to the server that should be returned
	// for the specified SQL string. Entries for the same SQL string are queued
	// and consumed one per streaming call.
	AddPartialResultSetError(sql string, err PartialResultSetExecutionTime)

	// Removes a mocked result on the server for a specific sql statement.
	RemoveStatementResult(sql string)

	// Aborts the specified transaction. This method can be used to test
	// transaction retry logic.
	AbortTransaction(id []byte)

	// Puts a simulated execution time for one of the Spanner methods. The
	// method must be one of the Method* constants.
	PutExecutionTime(method string, executionTime SimulatedExecutionTime)
	// Freeze stalls all requests.
	Freeze()
	// Unfreeze restores processing requests.
	Unfreeze()

	// TotalSessionsCreated returns the number of sessions created since the
	// server was constructed. Reset does not clear this counter.
	TotalSessionsCreated() uint
	// TotalSessionsDeleted returns the number of sessions deleted since the
	// server was constructed. Reset does not clear this counter.
	TotalSessionsDeleted() uint
	// SetMaxSessionsReturnedByServerPerBatchRequest caps the number of
	// sessions returned by a single BatchCreateSessions call.
	SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32)
	// SetMaxSessionsReturnedByServerInTotal caps the number of sessions that
	// may exist at the same time; a value <= 0 means no limit.
	SetMaxSessionsReturnedByServerInTotal(sessionCount int32)

	// ReceivedRequests returns the channel on which every request received by
	// the server is published. Stop closes it.
	ReceivedRequests() chan interface{}
	// DumpSessions returns the names of all currently existing sessions.
	DumpSessions() map[string]bool
	// ClearPings clears the recorded GetSession history.
	ClearPings()
	// DumpPings returns a copy of the session names passed to GetSession, in
	// call order.
	DumpPings() []string
}
218
// inMemSpannerServer is the in-memory implementation of InMemSpannerServer.
// The mutable fields below are meant to be accessed while holding mu.
type inMemSpannerServer struct {
	// Embed for forward compatibility.
	// Tests will keep working if more methods are added
	// in the future.
	spannerpb.SpannerServer

	mu sync.Mutex
	// Set to true when this server has been stopped. This is the end state of a
	// server, a stopped server cannot be restarted.
	stopped bool
	// If set, the next call that goes through simulateExecutionTime returns
	// this error, after which it is cleared.
	err error
	// The mock server creates session IDs using this counter.
	sessionCounter uint64
	// The sessions that have been created on this mock server.
	sessions map[string]*spannerpb.Session
	// Last use times per session.
	sessionLastUseTime map[string]time.Time
	// The mock server creates transaction IDs per session using these
	// counters.
	transactionCounters map[string]*uint64
	// The transactions that have been created on this mock server.
	transactions map[string]*spannerpb.Transaction
	// The transactions that have been (manually) aborted on the server.
	abortedTransactions map[string]bool
	// The transactions that are marked as PartitionedDMLTransaction
	partitionedDmlTransactions map[string]bool
	// The mocked results for this server.
	statementResults map[string]*StatementResult
	// The simulated execution times per method.
	executionTimes map[string]*SimulatedExecutionTime
	// The simulated errors for partial result sets, queued per SQL string.
	partialResultSetErrors map[string][]*PartialResultSetExecutionTime

	totalSessionsCreated uint
	totalSessionsDeleted uint
	// The maximum number of sessions that will be created per batch request.
	maxSessionsReturnedByServerPerBatchRequest int32
	// The maximum number of sessions that may exist at once (<= 0: no limit).
	maxSessionsReturnedByServerInTotal         int32
	// Every received request is sent to this channel.
	receivedRequests                           chan interface{}
	// Session ping history: the names passed to GetSession, in call order.
	pings []string

	// Gate for incoming requests: ready() blocks until this channel is
	// closed. Freeze installs an open channel, Unfreeze closes it.
	freezed chan struct{}
}
265
266// NewInMemSpannerServer creates a new in-mem test server.
267func NewInMemSpannerServer() InMemSpannerServer {
268	res := &inMemSpannerServer{}
269	res.initDefaults()
270	res.statementResults = make(map[string]*StatementResult)
271	res.executionTimes = make(map[string]*SimulatedExecutionTime)
272	res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime)
273	res.receivedRequests = make(chan interface{}, 1000000)
274	// Produce a closed channel, so the default action of ready is to not block.
275	res.Freeze()
276	res.Unfreeze()
277	return res
278}
279
280func (s *inMemSpannerServer) Stop() {
281	s.mu.Lock()
282	defer s.mu.Unlock()
283	s.stopped = true
284	close(s.receivedRequests)
285}
286
287// Resets the test server to its initial state, deleting all sessions and
288// transactions that have been created on the server. This method will not
289// remove mocked results.
290func (s *inMemSpannerServer) Reset() {
291	s.mu.Lock()
292	defer s.mu.Unlock()
293	close(s.receivedRequests)
294	s.receivedRequests = make(chan interface{}, 1000000)
295	s.initDefaults()
296}
297
298func (s *inMemSpannerServer) SetError(err error) {
299	s.mu.Lock()
300	defer s.mu.Unlock()
301	s.err = err
302}
303
304// Registers a mocked result for a SQL statement on the server.
305func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error {
306	s.mu.Lock()
307	defer s.mu.Unlock()
308	s.statementResults[sql] = result
309	return nil
310}
311
312func (s *inMemSpannerServer) RemoveStatementResult(sql string) {
313	s.mu.Lock()
314	defer s.mu.Unlock()
315	delete(s.statementResults, sql)
316}
317
318func (s *inMemSpannerServer) AbortTransaction(id []byte) {
319	s.mu.Lock()
320	defer s.mu.Unlock()
321	s.abortedTransactions[string(id)] = true
322}
323
324func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) {
325	s.mu.Lock()
326	defer s.mu.Unlock()
327	s.executionTimes[method] = &executionTime
328}
329
330func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) {
331	s.mu.Lock()
332	defer s.mu.Unlock()
333	s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError)
334}
335
336// Freeze stalls all requests.
337func (s *inMemSpannerServer) Freeze() {
338	s.mu.Lock()
339	defer s.mu.Unlock()
340	s.freezed = make(chan struct{})
341}
342
343// Unfreeze restores processing requests.
344func (s *inMemSpannerServer) Unfreeze() {
345	s.mu.Lock()
346	defer s.mu.Unlock()
347	close(s.freezed)
348}
349
350// ready checks conditions before executing requests
351func (s *inMemSpannerServer) ready() {
352	s.mu.Lock()
353	freezed := s.freezed
354	s.mu.Unlock()
355	// check if server should be freezed
356	<-freezed
357}
358
359func (s *inMemSpannerServer) TotalSessionsCreated() uint {
360	s.mu.Lock()
361	defer s.mu.Unlock()
362	return s.totalSessionsCreated
363}
364
365func (s *inMemSpannerServer) TotalSessionsDeleted() uint {
366	s.mu.Lock()
367	defer s.mu.Unlock()
368	return s.totalSessionsDeleted
369}
370
371func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) {
372	s.mu.Lock()
373	defer s.mu.Unlock()
374	s.maxSessionsReturnedByServerPerBatchRequest = sessionCount
375}
376
377func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) {
378	s.mu.Lock()
379	defer s.mu.Unlock()
380	s.maxSessionsReturnedByServerInTotal = sessionCount
381}
382
383func (s *inMemSpannerServer) ReceivedRequests() chan interface{} {
384	return s.receivedRequests
385}
386
387// ClearPings clears the ping history from the server.
388func (s *inMemSpannerServer) ClearPings() {
389	s.mu.Lock()
390	defer s.mu.Unlock()
391	s.pings = nil
392}
393
394// DumpPings dumps the ping history.
395func (s *inMemSpannerServer) DumpPings() []string {
396	s.mu.Lock()
397	defer s.mu.Unlock()
398	return append([]string(nil), s.pings...)
399}
400
401// DumpSessions dumps the internal session table.
402func (s *inMemSpannerServer) DumpSessions() map[string]bool {
403	s.mu.Lock()
404	defer s.mu.Unlock()
405	st := map[string]bool{}
406	for s := range s.sessions {
407		st[s] = true
408	}
409	return st
410}
411
412func (s *inMemSpannerServer) initDefaults() {
413	s.sessionCounter = 0
414	s.maxSessionsReturnedByServerPerBatchRequest = 100
415	s.sessions = make(map[string]*spannerpb.Session)
416	s.sessionLastUseTime = make(map[string]time.Time)
417	s.transactions = make(map[string]*spannerpb.Transaction)
418	s.abortedTransactions = make(map[string]bool)
419	s.partitionedDmlTransactions = make(map[string]bool)
420	s.transactionCounters = make(map[string]*uint64)
421}
422
423func (s *inMemSpannerServer) generateSessionNameLocked(database string) string {
424	s.sessionCounter++
425	return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter)
426}
427
428func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) {
429	s.mu.Lock()
430	defer s.mu.Unlock()
431	session := s.sessions[name]
432	if session == nil {
433		return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session %s not found", name))
434	}
435	return session, nil
436}
437
438func (s *inMemSpannerServer) updateSessionLastUseTime(session string) {
439	s.mu.Lock()
440	defer s.mu.Unlock()
441	s.sessionLastUseTime[session] = time.Now()
442}
443
444func getCurrentTimestamp() *timestamp.Timestamp {
445	t := time.Now()
446	return &timestamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())}
447}
448
449// Gets the transaction id from the transaction selector. If the selector
450// specifies that a new transaction should be started, this method will start
451// a new transaction and return the id of that transaction.
452func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte {
453	var res []byte
454	if txSelector.GetBegin() != nil {
455		// Start a new transaction.
456		res = s.beginTransaction(session, txSelector.GetBegin()).Id
457	} else if txSelector.GetId() != nil {
458		res = txSelector.GetId()
459	}
460	return res
461}
462
463func (s *inMemSpannerServer) generateTransactionName(session string) string {
464	s.mu.Lock()
465	defer s.mu.Unlock()
466	counter, ok := s.transactionCounters[session]
467	if !ok {
468		counter = new(uint64)
469		s.transactionCounters[session] = counter
470	}
471	*counter++
472	return fmt.Sprintf("%s/transactions/%d", session, *counter)
473}
474
475func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction {
476	id := s.generateTransactionName(session.Name)
477	res := &spannerpb.Transaction{
478		Id:            []byte(id),
479		ReadTimestamp: getCurrentTimestamp(),
480	}
481	s.mu.Lock()
482	s.transactions[id] = res
483	s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil
484	s.mu.Unlock()
485	return res
486}
487
488func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) {
489	s.mu.Lock()
490	defer s.mu.Unlock()
491	tx, ok := s.transactions[string(id)]
492	if !ok {
493		return nil, gstatus.Error(codes.NotFound, "Transaction not found")
494	}
495	aborted, ok := s.abortedTransactions[string(id)]
496	if ok && aborted {
497		return nil, gstatus.Error(codes.Aborted, "Transaction has been aborted")
498	}
499	return tx, nil
500}
501
502func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
503	s.mu.Lock()
504	defer s.mu.Unlock()
505	delete(s.transactions, string(tx.Id))
506	delete(s.partitionedDmlTransactions, string(tx.Id))
507}
508
509func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) {
510	s.mu.Lock()
511	defer s.mu.Unlock()
512	result, ok := s.statementResults[sql]
513	if !ok {
514		return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql))
515	}
516	return result, nil
517}
518
519func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error {
520	s.mu.Lock()
521	if s.stopped {
522		s.mu.Unlock()
523		return gstatus.Error(codes.Unavailable, "server has been stopped")
524	}
525	s.receivedRequests <- req
526	s.mu.Unlock()
527	s.ready()
528	s.mu.Lock()
529	if s.err != nil {
530		err := s.err
531		s.err = nil
532		s.mu.Unlock()
533		return err
534	}
535	executionTime, ok := s.executionTimes[method]
536	s.mu.Unlock()
537	if ok {
538		var randTime int64
539		if executionTime.RandomExecutionTime > 0 {
540			randTime = rand.Int63n(int64(executionTime.RandomExecutionTime))
541		}
542		totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime)
543		<-time.After(totalExecutionTime)
544		s.mu.Lock()
545		if executionTime.Errors != nil && len(executionTime.Errors) > 0 {
546			err := executionTime.Errors[0]
547			if !executionTime.KeepError {
548				executionTime.Errors = executionTime.Errors[1:]
549			}
550			s.mu.Unlock()
551			return err
552		}
553		s.mu.Unlock()
554	}
555	return nil
556}
557
558func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
559	if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
560		return nil, err
561	}
562	if req.Database == "" {
563		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
564	}
565	s.mu.Lock()
566	defer s.mu.Unlock()
567	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal {
568		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
569	}
570	sessionName := s.generateSessionNameLocked(req.Database)
571	ts := getCurrentTimestamp()
572	session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
573	s.totalSessionsCreated++
574	s.sessions[sessionName] = session
575	return session, nil
576}
577
578func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
579	if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
580		return nil, err
581	}
582	if req.Database == "" {
583		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
584	}
585	if req.SessionCount <= 0 {
586		return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0")
587	}
588	sessionsToCreate := req.SessionCount
589	s.mu.Lock()
590	defer s.mu.Unlock()
591	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal {
592		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
593	}
594	if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest {
595		sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest
596	}
597	if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal {
598		sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions))
599	}
600	sessions := make([]*spannerpb.Session, sessionsToCreate)
601	for i := int32(0); i < sessionsToCreate; i++ {
602		sessionName := s.generateSessionNameLocked(req.Database)
603		ts := getCurrentTimestamp()
604		sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
605		s.totalSessionsCreated++
606		s.sessions[sessionName] = sessions[i]
607	}
608	return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil
609}
610
611func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
612	if err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
613		return nil, err
614	}
615	s.mu.Lock()
616	s.pings = append(s.pings, req.Name)
617	s.mu.Unlock()
618	if req.Name == "" {
619		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
620	}
621	session, err := s.findSession(req.Name)
622	if err != nil {
623		return nil, err
624	}
625	return session, nil
626}
627
628func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) {
629	s.mu.Lock()
630	if s.stopped {
631		s.mu.Unlock()
632		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
633	}
634	s.receivedRequests <- req
635	s.mu.Unlock()
636	if req.Database == "" {
637		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
638	}
639	expectedSessionName := req.Database + "/sessions/"
640	var sessions []*spannerpb.Session
641	s.mu.Lock()
642	for _, session := range s.sessions {
643		if strings.Index(session.Name, expectedSessionName) == 0 {
644			sessions = append(sessions, session)
645		}
646	}
647	s.mu.Unlock()
648	sort.Slice(sessions[:], func(i, j int) bool {
649		return sessions[i].Name < sessions[j].Name
650	})
651	res := &spannerpb.ListSessionsResponse{Sessions: sessions}
652	return res, nil
653}
654
655func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
656	if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
657		return nil, err
658	}
659	if req.Name == "" {
660		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
661	}
662	if _, err := s.findSession(req.Name); err != nil {
663		return nil, err
664	}
665	s.mu.Lock()
666	defer s.mu.Unlock()
667	s.totalSessionsDeleted++
668	delete(s.sessions, req.Name)
669	return &emptypb.Empty{}, nil
670}
671
672func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
673	if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
674		return nil, err
675	}
676	if req.Session == "" {
677		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
678	}
679	session, err := s.findSession(req.Session)
680	if err != nil {
681		return nil, err
682	}
683	var id []byte
684	s.updateSessionLastUseTime(session.Name)
685	if id = s.getTransactionID(session, req.Transaction); id != nil {
686		_, err = s.getTransactionByID(id)
687		if err != nil {
688			return nil, err
689		}
690	}
691	statementResult, err := s.getStatementResult(req.Sql)
692	if err != nil {
693		return nil, err
694	}
695	s.mu.Lock()
696	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
697	s.mu.Unlock()
698	switch statementResult.Type {
699	case StatementResultError:
700		return nil, statementResult.Err
701	case StatementResultResultSet:
702		return statementResult.ResultSet, nil
703	case StatementResultUpdateCount:
704		return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil
705	}
706	return nil, gstatus.Error(codes.Internal, "Unknown result type")
707}
708
// ExecuteStreamingSql streams the mocked result registered for req.Sql.
//
// Result sets are split into PartialResultSets by toPartialResultSets, starting
// at req.ResumeToken when it is set. Before streaming, the first
// PartialResultSetExecutionTime queued for req.Sql by AddPartialResultSetError
// is dequeued. When the stream reaches the part whose resume token equals that
// entry's ResumeToken, the server waits for its ExecutionTime and, if Err is
// set, ends the stream with that error instead of sending the part. Update
// counts are sent as one PartialResultSet with statistics only: a lower bound
// for Partitioned DML transactions, an exact count otherwise.
func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
	if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
		return err
	}
	if req.Session == "" {
		return gstatus.Error(codes.InvalidArgument, "Missing session name")
	}
	session, err := s.findSession(req.Session)
	if err != nil {
		return err
	}
	s.updateSessionLastUseTime(session.Name)
	var id []byte
	// Resolve (or, for a Begin selector, start) the transaction and verify
	// that it exists and has not been aborted.
	if id = s.getTransactionID(session, req.Transaction); id != nil {
		_, err = s.getTransactionByID(id)
		if err != nil {
			return err
		}
	}
	statementResult, err := s.getStatementResult(req.Sql)
	if err != nil {
		return err
	}
	s.mu.Lock()
	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
	s.mu.Unlock()
	switch statementResult.Type {
	case StatementResultError:
		return statementResult.Err
	case StatementResultResultSet:
		parts, err := statementResult.toPartialResultSets(req.ResumeToken)
		if err != nil {
			return err
		}
		// Dequeue at most one injected delay/error for this statement.
		// NOTE(review): if its resume token matches none of the parts, the
		// entry is consumed without any effect.
		var nextPartialResultSetError *PartialResultSetExecutionTime
		s.mu.Lock()
		pErrors := s.partialResultSetErrors[req.Sql]
		if len(pErrors) > 0 {
			nextPartialResultSetError = pErrors[0]
			s.partialResultSetErrors[req.Sql] = pErrors[1:]
		}
		s.mu.Unlock()
		for _, part := range parts {
			// Apply the injected delay/error just before sending the part
			// whose resume token matches.
			if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) {
				if nextPartialResultSetError.ExecutionTime > 0 {
					<-time.After(nextPartialResultSetError.ExecutionTime)
				}
				if nextPartialResultSetError.Err != nil {
					return nextPartialResultSetError.Err
				}
			}
			if err := stream.Send(part); err != nil {
				return err
			}
		}
		return nil
	case StatementResultUpdateCount:
		part := statementResult.updateCountToPartialResultSet(!isPartitionedDml)
		if err := stream.Send(part); err != nil {
			return err
		}
		return nil
	}
	return gstatus.Error(codes.Internal, "Unknown result type")
}
774
775func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
776	s.mu.Lock()
777	if s.stopped {
778		s.mu.Unlock()
779		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
780	}
781	s.receivedRequests <- req
782	s.mu.Unlock()
783	if req.Session == "" {
784		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
785	}
786	session, err := s.findSession(req.Session)
787	if err != nil {
788		return nil, err
789	}
790	s.updateSessionLastUseTime(session.Name)
791	var id []byte
792	if id = s.getTransactionID(session, req.Transaction); id != nil {
793		_, err = s.getTransactionByID(id)
794		if err != nil {
795			return nil, err
796		}
797	}
798	s.mu.Lock()
799	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
800	s.mu.Unlock()
801	resp := &spannerpb.ExecuteBatchDmlResponse{}
802	resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements))
803	for idx, batchStatement := range req.Statements {
804		statementResult, err := s.getStatementResult(batchStatement.Sql)
805		if err != nil {
806			return nil, err
807		}
808		switch statementResult.Type {
809		case StatementResultError:
810			resp.Status = &status.Status{Code: int32(codes.Unknown)}
811		case StatementResultResultSet:
812			return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql))
813		case StatementResultUpdateCount:
814			resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
815			resp.Status = &status.Status{Code: int32(codes.OK)}
816		}
817	}
818	return resp, nil
819}
820
821func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) {
822	s.mu.Lock()
823	if s.stopped {
824		s.mu.Unlock()
825		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
826	}
827	s.receivedRequests <- req
828	s.mu.Unlock()
829	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
830}
831
832func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
833	s.mu.Lock()
834	if s.stopped {
835		s.mu.Unlock()
836		return gstatus.Error(codes.Unavailable, "server has been stopped")
837	}
838	s.receivedRequests <- req
839	s.mu.Unlock()
840	return gstatus.Error(codes.Unimplemented, "Method not yet implemented")
841}
842
843func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
844	if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
845		return nil, err
846	}
847	if req.Session == "" {
848		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
849	}
850	session, err := s.findSession(req.Session)
851	if err != nil {
852		return nil, err
853	}
854	s.updateSessionLastUseTime(session.Name)
855	tx := s.beginTransaction(session, req.Options)
856	return tx, nil
857}
858
859func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) {
860	if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil {
861		return nil, err
862	}
863	if req.Session == "" {
864		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
865	}
866	session, err := s.findSession(req.Session)
867	if err != nil {
868		return nil, err
869	}
870	s.updateSessionLastUseTime(session.Name)
871	var tx *spannerpb.Transaction
872	if req.GetSingleUseTransaction() != nil {
873		tx = s.beginTransaction(session, req.GetSingleUseTransaction())
874	} else if req.GetTransactionId() != nil {
875		tx, err = s.getTransactionByID(req.GetTransactionId())
876		if err != nil {
877			return nil, err
878		}
879	} else {
880		return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
881	}
882	s.removeTransaction(tx)
883	return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil
884}
885
886func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
887	s.mu.Lock()
888	if s.stopped {
889		s.mu.Unlock()
890		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
891	}
892	s.receivedRequests <- req
893	s.mu.Unlock()
894	if req.Session == "" {
895		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
896	}
897	session, err := s.findSession(req.Session)
898	if err != nil {
899		return nil, err
900	}
901	s.updateSessionLastUseTime(session.Name)
902	tx, err := s.getTransactionByID(req.TransactionId)
903	if err != nil {
904		return nil, err
905	}
906	s.removeTransaction(tx)
907	return &emptypb.Empty{}, nil
908}
909
910func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
911	s.mu.Lock()
912	if s.stopped {
913		s.mu.Unlock()
914		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
915	}
916	s.receivedRequests <- req
917	s.mu.Unlock()
918	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
919}
920
921func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) {
922	s.mu.Lock()
923	if s.stopped {
924		s.mu.Unlock()
925		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
926	}
927	s.receivedRequests <- req
928	s.mu.Unlock()
929	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
930}
931