1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package testutil
16
17import (
18	"bytes"
19	"context"
20	"fmt"
21	"math/rand"
22	"sort"
23	"strings"
24	"sync"
25	"time"
26
27	"github.com/golang/protobuf/ptypes"
28	emptypb "github.com/golang/protobuf/ptypes/empty"
29	structpb "github.com/golang/protobuf/ptypes/struct"
30	"github.com/golang/protobuf/ptypes/timestamp"
31	"google.golang.org/genproto/googleapis/rpc/errdetails"
32	"google.golang.org/genproto/googleapis/rpc/status"
33	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
34	"google.golang.org/grpc/codes"
35	gstatus "google.golang.org/grpc/status"
36)
37
// StatementResultType indicates the type of result returned by a SQL
// statement. It selects which field of a StatementResult (Err, ResultSet or
// UpdateCount) the mock server returns for that statement.
type StatementResultType int
41
const (
	// StatementResultError indicates that the sql statement returns an error.
	// The error returned is StatementResult.Err.
	StatementResultError StatementResultType = 0
	// StatementResultResultSet indicates that the sql statement returns a
	// result set. The result set returned is StatementResult.ResultSet.
	StatementResultResultSet StatementResultType = 1
	// StatementResultUpdateCount indicates that the sql statement returns an
	// update count. The count returned is StatementResult.UpdateCount.
	StatementResultUpdateCount StatementResultType = 2
	// MaxRowsPerPartialResultSet is the maximum number of rows returned in
	// each PartialResultSet. This number is deliberately set to a low value to
	// ensure that most queries return more than one PartialResultSet.
	// Note that, despite being declared in this block, it is an untyped
	// integer constant and not a StatementResultType.
	MaxRowsPerPartialResultSet = 1
)
56
// The method names that can be used to register execution times and errors.
// Pass one of these to PutExecutionTime; each server method looks up the
// SimulatedExecutionTime registered under its own name (e.g. Commit uses
// MethodCommitTransaction, ExecuteStreamingSql uses
// MethodExecuteStreamingSql).
const (
	MethodBeginTransaction    string = "BEGIN_TRANSACTION"
	MethodCommitTransaction   string = "COMMIT_TRANSACTION"
	MethodBatchCreateSession  string = "BATCH_CREATE_SESSION"
	MethodCreateSession       string = "CREATE_SESSION"
	MethodDeleteSession       string = "DELETE_SESSION"
	MethodGetSession          string = "GET_SESSION"
	MethodExecuteSql          string = "EXECUTE_SQL"
	MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL"
	MethodExecuteBatchDml     string = "EXECUTE_BATCH_DML"
)
69
// StatementResult represents a mocked result on the test server. The result is
// either of: a ResultSet, an update count or an error.
type StatementResult struct {
	// Type selects which of Err, ResultSet or UpdateCount is returned.
	Type        StatementResultType
	// Err is returned when Type is StatementResultError.
	Err         error
	// ResultSet is returned (or streamed as PartialResultSets) when Type is
	// StatementResultResultSet.
	ResultSet   *spannerpb.ResultSet
	// UpdateCount is returned as the row-count statistic of a ResultSet when
	// Type is StatementResultUpdateCount.
	UpdateCount int64
}
78
// PartialResultSetExecutionTime represents execution times and errors that
// should be used when a PartialResult at the specified resume token is to
// be returned.
//
// Entries are queued per SQL string with AddPartialResultSetError. Each
// ExecuteStreamingSql call for that SQL string dequeues at most one entry.
type PartialResultSetExecutionTime struct {
	// ResumeToken identifies the PartialResultSet (by its resume token) at
	// which the delay and/or error are applied.
	ResumeToken   []byte
	// ExecutionTime is how long to wait before sending (or failing instead
	// of sending) that PartialResultSet. Zero means no delay.
	ExecutionTime time.Duration
	// Err, if non-nil, ends the stream with this error instead of sending
	// that PartialResultSet.
	Err           error
}
87
88// Converts a ResultSet to a PartialResultSet. This method is used to convert
89// a mocked result to a PartialResultSet when one of the streaming methods are
90// called.
91func (s *StatementResult) toPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) {
92	var startIndex uint64
93	if len(resumeToken) > 0 {
94		if startIndex, err = DecodeResumeToken(resumeToken); err != nil {
95			return nil, err
96		}
97	}
98
99	totalRows := uint64(len(s.ResultSet.Rows))
100	for {
101		rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet))
102		rows := s.ResultSet.Rows[startIndex : startIndex+rowCount]
103		values := make([]*structpb.Value,
104			len(rows)*len(s.ResultSet.Metadata.RowType.Fields))
105		var idx int
106		for _, row := range rows {
107			for colIdx := range s.ResultSet.Metadata.RowType.Fields {
108				values[idx] = row.Values[colIdx]
109				idx++
110			}
111		}
112		result = append(result, &spannerpb.PartialResultSet{
113			Metadata:    s.ResultSet.Metadata,
114			Values:      values,
115			ResumeToken: EncodeResumeToken(startIndex + rowCount),
116		})
117		startIndex += rowCount
118		if startIndex == totalRows {
119			break
120		}
121	}
122	return result, nil
123}
124
// min returns the smaller of the two unsigned integers x and y.
func min(x, y uint64) uint64 {
	if y < x {
		return y
	}
	return x
}
131
132func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet {
133	return &spannerpb.PartialResultSet{
134		Stats: s.convertUpdateCountToResultSet(exact).Stats,
135	}
136}
137
138// Converts an update count to a ResultSet, as DML statements also return the
139// update count as the statistics of a ResultSet.
140func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet {
141	if exact {
142		return &spannerpb.ResultSet{
143			Stats: &spannerpb.ResultSetStats{
144				RowCount: &spannerpb.ResultSetStats_RowCountExact{
145					RowCountExact: s.UpdateCount,
146				},
147			},
148		}
149	}
150	return &spannerpb.ResultSet{
151		Stats: &spannerpb.ResultSetStats{
152			RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{
153				RowCountLowerBound: s.UpdateCount,
154			},
155		},
156	}
157}
158
// SimulatedExecutionTime represents the time the execution of a method
// should take, and any errors that should be returned by the method.
//
// Each call to the method waits MinimumExecutionTime plus a random duration
// in [0, RandomExecutionTime). It then returns the first entry of Errors, if
// there is one.
type SimulatedExecutionTime struct {
	// MinimumExecutionTime is the fixed part of the simulated delay.
	MinimumExecutionTime time.Duration
	// RandomExecutionTime is the exclusive upper bound of the random part of
	// the simulated delay. A value <= 0 disables the random part.
	RandomExecutionTime  time.Duration
	// Errors are returned one per call, in order, after the delay. Each
	// returned error is removed from the list unless KeepError is set.
	Errors               []error
	// Keep error after execution. The error will continue to be returned until
	// it is cleared.
	KeepError bool
}
169
// InMemSpannerServer contains the SpannerServer interface plus a couple
// of specific methods for adding mocked results and resetting the server.
type InMemSpannerServer interface {
	spannerpb.SpannerServer

	// Stops this server. A stopped server cannot be restarted; requests
	// received after Stop fail with codes.Unavailable.
	Stop()

	// Resets the in-mem server to its default state, deleting all sessions and
	// transactions that have been created on the server. Mocked results are
	// not deleted.
	Reset()

	// Sets an error that will be returned by the next server call. The server
	// call will also automatically clear the error.
	// NOTE: ListSessions, Read, StreamingRead, Rollback, PartitionQuery and
	// PartitionRead do not check for this error.
	SetError(err error)

	// Puts a mocked result on the server for a specific sql statement. The
	// server does not parse the SQL string in any way, it is merely used as
	// a key to the mocked result. The result will be used for all methods that
	// expect a SQL statement, including (batch) DML methods.
	PutStatementResult(sql string, result *StatementResult) error

	// Adds a PartialResultSetExecutionTime to the server that should be returned
	// for the specified SQL string. Entries are queued per SQL string, and each
	// ExecuteStreamingSql call for that string consumes at most one entry.
	AddPartialResultSetError(sql string, err PartialResultSetExecutionTime)

	// Removes a mocked result on the server for a specific sql statement.
	RemoveStatementResult(sql string)

	// Aborts the specified transaction. This method can be used to test
	// transaction retry logic.
	AbortTransaction(id []byte)

	// Puts a simulated execution time for one of the Spanner methods. The
	// method name should be one of the Method* constants.
	PutExecutionTime(method string, executionTime SimulatedExecutionTime)
	// Freeze stalls incoming requests until Unfreeze is called.
	// NOTE: ListSessions, Read, StreamingRead, Rollback, PartitionQuery and
	// PartitionRead are not stalled.
	Freeze()
	// Unfreeze restores processing requests.
	Unfreeze()

	// TotalSessionsCreated returns the number of sessions that have been
	// created on the server. Reset does not clear this counter.
	TotalSessionsCreated() uint
	// TotalSessionsDeleted returns the number of sessions that have been
	// deleted on the server. Reset does not clear this counter.
	TotalSessionsDeleted() uint
	// SetMaxSessionsReturnedByServerPerBatchRequest limits the number of
	// sessions returned by a single BatchCreateSessions call.
	SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32)
	// SetMaxSessionsReturnedByServerInTotal limits the number of sessions
	// that may exist on the server at the same time. A value <= 0 means no
	// limit.
	SetMaxSessionsReturnedByServerInTotal(sessionCount int32)

	// ReceivedRequests returns the channel to which every request received by
	// the server is sent. The channel is closed by Stop.
	ReceivedRequests() chan interface{}
	// DumpSessions returns the set of names of all sessions on the server.
	DumpSessions() map[string]bool
	// ClearPings clears the GetSession history returned by DumpPings.
	ClearPings()
	// DumpPings returns the session names of all GetSession requests, in
	// arrival order.
	DumpPings() []string
}
221
// inMemSpannerServer is the implementation of InMemSpannerServer that is
// returned by NewInMemSpannerServer.
type inMemSpannerServer struct {
	// Embed for forward compatibility.
	// Tests will keep working if more methods are added
	// in the future.
	// NOTE: the embedded interface value is never set, so calling a
	// SpannerServer method that inMemSpannerServer does not implement itself
	// panics with a nil pointer dereference.
	spannerpb.SpannerServer

	mu sync.Mutex
	// Set to true when this server has been stopped. This is the end state of a
	// server, a stopped server cannot be restarted.
	stopped bool
	// If set, the next call that simulates an execution time returns this
	// error, and the error is then cleared.
	err error
	// The mock server creates session IDs using this counter.
	sessionCounter uint64
	// The sessions that have been created on this mock server, keyed by
	// session name.
	sessions map[string]*spannerpb.Session
	// Last use times per session.
	sessionLastUseTime map[string]time.Time
	// The mock server creates transaction IDs per session using these
	// counters.
	transactionCounters map[string]*uint64
	// The transactions that have been created on this mock server, keyed by
	// string(transaction id).
	transactions map[string]*spannerpb.Transaction
	// The transactions that have been (manually) aborted on the server.
	abortedTransactions map[string]bool
	// The transactions that are marked as PartitionedDMLTransaction
	partitionedDmlTransactions map[string]bool
	// The mocked results for this server.
	statementResults map[string]*StatementResult
	// The simulated execution times per method.
	executionTimes map[string]*SimulatedExecutionTime
	// The simulated errors for partial result sets, queued per SQL string.
	partialResultSetErrors map[string][]*PartialResultSetExecutionTime

	// Number of sessions created on / deleted from this server. These are
	// not cleared by Reset.
	totalSessionsCreated uint
	totalSessionsDeleted uint
	// The maximum number of sessions that will be created per batch request,
	// the maximum number of sessions that may exist at the same time (<= 0
	// means unlimited), and the channel that every received request is sent
	// to.
	maxSessionsReturnedByServerPerBatchRequest int32
	maxSessionsReturnedByServerInTotal         int32
	receivedRequests                           chan interface{}
	// Session ping history.
	pings []string

	// Server will stall on any requests.
	// Requests wait in ready() until this channel is closed, so an open
	// channel means the server is frozen.
	freezed chan struct{}
}
268
269// NewInMemSpannerServer creates a new in-mem test server.
270func NewInMemSpannerServer() InMemSpannerServer {
271	res := &inMemSpannerServer{}
272	res.initDefaults()
273	res.statementResults = make(map[string]*StatementResult)
274	res.executionTimes = make(map[string]*SimulatedExecutionTime)
275	res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime)
276	res.receivedRequests = make(chan interface{}, 1000000)
277	// Produce a closed channel, so the default action of ready is to not block.
278	res.Freeze()
279	res.Unfreeze()
280	return res
281}
282
283func (s *inMemSpannerServer) Stop() {
284	s.mu.Lock()
285	defer s.mu.Unlock()
286	s.stopped = true
287	close(s.receivedRequests)
288}
289
290// Resets the test server to its initial state, deleting all sessions and
291// transactions that have been created on the server. This method will not
292// remove mocked results.
293func (s *inMemSpannerServer) Reset() {
294	s.mu.Lock()
295	defer s.mu.Unlock()
296	close(s.receivedRequests)
297	s.receivedRequests = make(chan interface{}, 1000000)
298	s.initDefaults()
299}
300
301func (s *inMemSpannerServer) SetError(err error) {
302	s.mu.Lock()
303	defer s.mu.Unlock()
304	s.err = err
305}
306
307// Registers a mocked result for a SQL statement on the server.
308func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error {
309	s.mu.Lock()
310	defer s.mu.Unlock()
311	s.statementResults[sql] = result
312	return nil
313}
314
315func (s *inMemSpannerServer) RemoveStatementResult(sql string) {
316	s.mu.Lock()
317	defer s.mu.Unlock()
318	delete(s.statementResults, sql)
319}
320
321func (s *inMemSpannerServer) AbortTransaction(id []byte) {
322	s.mu.Lock()
323	defer s.mu.Unlock()
324	s.abortedTransactions[string(id)] = true
325}
326
327func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) {
328	s.mu.Lock()
329	defer s.mu.Unlock()
330	s.executionTimes[method] = &executionTime
331}
332
333func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) {
334	s.mu.Lock()
335	defer s.mu.Unlock()
336	s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError)
337}
338
339// Freeze stalls all requests.
340func (s *inMemSpannerServer) Freeze() {
341	s.mu.Lock()
342	defer s.mu.Unlock()
343	s.freezed = make(chan struct{})
344}
345
346// Unfreeze restores processing requests.
347func (s *inMemSpannerServer) Unfreeze() {
348	s.mu.Lock()
349	defer s.mu.Unlock()
350	close(s.freezed)
351}
352
353// ready checks conditions before executing requests
354func (s *inMemSpannerServer) ready() {
355	s.mu.Lock()
356	freezed := s.freezed
357	s.mu.Unlock()
358	// check if server should be freezed
359	<-freezed
360}
361
362func (s *inMemSpannerServer) TotalSessionsCreated() uint {
363	s.mu.Lock()
364	defer s.mu.Unlock()
365	return s.totalSessionsCreated
366}
367
368func (s *inMemSpannerServer) TotalSessionsDeleted() uint {
369	s.mu.Lock()
370	defer s.mu.Unlock()
371	return s.totalSessionsDeleted
372}
373
374func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) {
375	s.mu.Lock()
376	defer s.mu.Unlock()
377	s.maxSessionsReturnedByServerPerBatchRequest = sessionCount
378}
379
380func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) {
381	s.mu.Lock()
382	defer s.mu.Unlock()
383	s.maxSessionsReturnedByServerInTotal = sessionCount
384}
385
386func (s *inMemSpannerServer) ReceivedRequests() chan interface{} {
387	return s.receivedRequests
388}
389
390// ClearPings clears the ping history from the server.
391func (s *inMemSpannerServer) ClearPings() {
392	s.mu.Lock()
393	defer s.mu.Unlock()
394	s.pings = nil
395}
396
397// DumpPings dumps the ping history.
398func (s *inMemSpannerServer) DumpPings() []string {
399	s.mu.Lock()
400	defer s.mu.Unlock()
401	return append([]string(nil), s.pings...)
402}
403
404// DumpSessions dumps the internal session table.
405func (s *inMemSpannerServer) DumpSessions() map[string]bool {
406	s.mu.Lock()
407	defer s.mu.Unlock()
408	st := map[string]bool{}
409	for s := range s.sessions {
410		st[s] = true
411	}
412	return st
413}
414
415func (s *inMemSpannerServer) initDefaults() {
416	s.sessionCounter = 0
417	s.maxSessionsReturnedByServerPerBatchRequest = 100
418	s.sessions = make(map[string]*spannerpb.Session)
419	s.sessionLastUseTime = make(map[string]time.Time)
420	s.transactions = make(map[string]*spannerpb.Transaction)
421	s.abortedTransactions = make(map[string]bool)
422	s.partitionedDmlTransactions = make(map[string]bool)
423	s.transactionCounters = make(map[string]*uint64)
424}
425
426func (s *inMemSpannerServer) generateSessionNameLocked(database string) string {
427	s.sessionCounter++
428	return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter)
429}
430
431func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) {
432	s.mu.Lock()
433	defer s.mu.Unlock()
434	session := s.sessions[name]
435	if session == nil {
436		return nil, newSessionNotFoundError(name)
437	}
438	return session, nil
439}
440
441// sessionResourceType is the type name of Spanner sessions.
442const sessionResourceType = "type.googleapis.com/google.spanner.v1.Session"
443
444func newSessionNotFoundError(name string) error {
445	s := gstatus.Newf(codes.NotFound, "Session not found: Session with id %s not found", name)
446	s, _ = s.WithDetails(&errdetails.ResourceInfo{ResourceType: sessionResourceType, ResourceName: name})
447	return s.Err()
448}
449
450func (s *inMemSpannerServer) updateSessionLastUseTime(session string) {
451	s.mu.Lock()
452	defer s.mu.Unlock()
453	s.sessionLastUseTime[session] = time.Now()
454}
455
456func getCurrentTimestamp() *timestamp.Timestamp {
457	t := time.Now()
458	return &timestamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())}
459}
460
461// Gets the transaction id from the transaction selector. If the selector
462// specifies that a new transaction should be started, this method will start
463// a new transaction and return the id of that transaction.
464func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte {
465	var res []byte
466	if txSelector.GetBegin() != nil {
467		// Start a new transaction.
468		res = s.beginTransaction(session, txSelector.GetBegin()).Id
469	} else if txSelector.GetId() != nil {
470		res = txSelector.GetId()
471	}
472	return res
473}
474
475func (s *inMemSpannerServer) generateTransactionName(session string) string {
476	s.mu.Lock()
477	defer s.mu.Unlock()
478	counter, ok := s.transactionCounters[session]
479	if !ok {
480		counter = new(uint64)
481		s.transactionCounters[session] = counter
482	}
483	*counter++
484	return fmt.Sprintf("%s/transactions/%d", session, *counter)
485}
486
487func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction {
488	id := s.generateTransactionName(session.Name)
489	res := &spannerpb.Transaction{
490		Id:            []byte(id),
491		ReadTimestamp: getCurrentTimestamp(),
492	}
493	s.mu.Lock()
494	s.transactions[id] = res
495	s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil
496	s.mu.Unlock()
497	return res
498}
499
500func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) {
501	s.mu.Lock()
502	defer s.mu.Unlock()
503	tx, ok := s.transactions[string(id)]
504	if !ok {
505		return nil, gstatus.Error(codes.NotFound, "Transaction not found")
506	}
507	aborted, ok := s.abortedTransactions[string(id)]
508	if ok && aborted {
509		return nil, newAbortedErrorWithMinimalRetryDelay()
510	}
511	return tx, nil
512}
513
514func newAbortedErrorWithMinimalRetryDelay() error {
515	st := gstatus.New(codes.Aborted, "Transaction has been aborted")
516	retry := &errdetails.RetryInfo{
517		RetryDelay: ptypes.DurationProto(time.Nanosecond),
518	}
519	st, _ = st.WithDetails(retry)
520	return st.Err()
521}
522
523func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
524	s.mu.Lock()
525	defer s.mu.Unlock()
526	delete(s.transactions, string(tx.Id))
527	delete(s.partitionedDmlTransactions, string(tx.Id))
528}
529
530func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) {
531	s.mu.Lock()
532	defer s.mu.Unlock()
533	result, ok := s.statementResults[sql]
534	if !ok {
535		return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql))
536	}
537	return result, nil
538}
539
540func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error {
541	s.mu.Lock()
542	if s.stopped {
543		s.mu.Unlock()
544		return gstatus.Error(codes.Unavailable, "server has been stopped")
545	}
546	s.receivedRequests <- req
547	s.mu.Unlock()
548	s.ready()
549	s.mu.Lock()
550	if s.err != nil {
551		err := s.err
552		s.err = nil
553		s.mu.Unlock()
554		return err
555	}
556	executionTime, ok := s.executionTimes[method]
557	s.mu.Unlock()
558	if ok {
559		var randTime int64
560		if executionTime.RandomExecutionTime > 0 {
561			randTime = rand.Int63n(int64(executionTime.RandomExecutionTime))
562		}
563		totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime)
564		<-time.After(totalExecutionTime)
565		s.mu.Lock()
566		if executionTime.Errors != nil && len(executionTime.Errors) > 0 {
567			err := executionTime.Errors[0]
568			if !executionTime.KeepError {
569				executionTime.Errors = executionTime.Errors[1:]
570			}
571			s.mu.Unlock()
572			return err
573		}
574		s.mu.Unlock()
575	}
576	return nil
577}
578
579func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
580	if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
581		return nil, err
582	}
583	if req.Database == "" {
584		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
585	}
586	s.mu.Lock()
587	defer s.mu.Unlock()
588	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal {
589		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
590	}
591	sessionName := s.generateSessionNameLocked(req.Database)
592	ts := getCurrentTimestamp()
593	session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
594	s.totalSessionsCreated++
595	s.sessions[sessionName] = session
596	return session, nil
597}
598
599func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
600	if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
601		return nil, err
602	}
603	if req.Database == "" {
604		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
605	}
606	if req.SessionCount <= 0 {
607		return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0")
608	}
609	sessionsToCreate := req.SessionCount
610	s.mu.Lock()
611	defer s.mu.Unlock()
612	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal {
613		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
614	}
615	if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest {
616		sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest
617	}
618	if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal {
619		sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions))
620	}
621	sessions := make([]*spannerpb.Session, sessionsToCreate)
622	for i := int32(0); i < sessionsToCreate; i++ {
623		sessionName := s.generateSessionNameLocked(req.Database)
624		ts := getCurrentTimestamp()
625		sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
626		s.totalSessionsCreated++
627		s.sessions[sessionName] = sessions[i]
628	}
629	return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil
630}
631
632func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
633	if err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
634		return nil, err
635	}
636	s.mu.Lock()
637	s.pings = append(s.pings, req.Name)
638	s.mu.Unlock()
639	if req.Name == "" {
640		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
641	}
642	session, err := s.findSession(req.Name)
643	if err != nil {
644		return nil, err
645	}
646	return session, nil
647}
648
649func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) {
650	s.mu.Lock()
651	if s.stopped {
652		s.mu.Unlock()
653		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
654	}
655	s.receivedRequests <- req
656	s.mu.Unlock()
657	if req.Database == "" {
658		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
659	}
660	expectedSessionName := req.Database + "/sessions/"
661	var sessions []*spannerpb.Session
662	s.mu.Lock()
663	for _, session := range s.sessions {
664		if strings.Index(session.Name, expectedSessionName) == 0 {
665			sessions = append(sessions, session)
666		}
667	}
668	s.mu.Unlock()
669	sort.Slice(sessions[:], func(i, j int) bool {
670		return sessions[i].Name < sessions[j].Name
671	})
672	res := &spannerpb.ListSessionsResponse{Sessions: sessions}
673	return res, nil
674}
675
676func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
677	if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
678		return nil, err
679	}
680	if req.Name == "" {
681		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
682	}
683	if _, err := s.findSession(req.Name); err != nil {
684		return nil, err
685	}
686	s.mu.Lock()
687	defer s.mu.Unlock()
688	s.totalSessionsDeleted++
689	delete(s.sessions, req.Name)
690	return &emptypb.Empty{}, nil
691}
692
693func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
694	if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
695		return nil, err
696	}
697	if req.Session == "" {
698		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
699	}
700	session, err := s.findSession(req.Session)
701	if err != nil {
702		return nil, err
703	}
704	var id []byte
705	s.updateSessionLastUseTime(session.Name)
706	if id = s.getTransactionID(session, req.Transaction); id != nil {
707		_, err = s.getTransactionByID(id)
708		if err != nil {
709			return nil, err
710		}
711	}
712	statementResult, err := s.getStatementResult(req.Sql)
713	if err != nil {
714		return nil, err
715	}
716	s.mu.Lock()
717	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
718	s.mu.Unlock()
719	switch statementResult.Type {
720	case StatementResultError:
721		return nil, statementResult.Err
722	case StatementResultResultSet:
723		return statementResult.ResultSet, nil
724	case StatementResultUpdateCount:
725		return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil
726	}
727	return nil, gstatus.Error(codes.Internal, "Unknown result type")
728}
729
730func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
731	if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
732		return err
733	}
734	if req.Session == "" {
735		return gstatus.Error(codes.InvalidArgument, "Missing session name")
736	}
737	session, err := s.findSession(req.Session)
738	if err != nil {
739		return err
740	}
741	s.updateSessionLastUseTime(session.Name)
742	var id []byte
743	if id = s.getTransactionID(session, req.Transaction); id != nil {
744		_, err = s.getTransactionByID(id)
745		if err != nil {
746			return err
747		}
748	}
749	statementResult, err := s.getStatementResult(req.Sql)
750	if err != nil {
751		return err
752	}
753	s.mu.Lock()
754	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
755	s.mu.Unlock()
756	switch statementResult.Type {
757	case StatementResultError:
758		return statementResult.Err
759	case StatementResultResultSet:
760		parts, err := statementResult.toPartialResultSets(req.ResumeToken)
761		if err != nil {
762			return err
763		}
764		var nextPartialResultSetError *PartialResultSetExecutionTime
765		s.mu.Lock()
766		pErrors := s.partialResultSetErrors[req.Sql]
767		if len(pErrors) > 0 {
768			nextPartialResultSetError = pErrors[0]
769			s.partialResultSetErrors[req.Sql] = pErrors[1:]
770		}
771		s.mu.Unlock()
772		for _, part := range parts {
773			if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) {
774				if nextPartialResultSetError.ExecutionTime > 0 {
775					<-time.After(nextPartialResultSetError.ExecutionTime)
776				}
777				if nextPartialResultSetError.Err != nil {
778					return nextPartialResultSetError.Err
779				}
780			}
781			if err := stream.Send(part); err != nil {
782				return err
783			}
784		}
785		return nil
786	case StatementResultUpdateCount:
787		part := statementResult.updateCountToPartialResultSet(!isPartitionedDml)
788		if err := stream.Send(part); err != nil {
789			return err
790		}
791		return nil
792	}
793	return gstatus.Error(codes.Internal, "Unknown result type")
794}
795
796func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
797	if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
798		return nil, err
799	}
800	if req.Session == "" {
801		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
802	}
803	session, err := s.findSession(req.Session)
804	if err != nil {
805		return nil, err
806	}
807	s.updateSessionLastUseTime(session.Name)
808	var id []byte
809	if id = s.getTransactionID(session, req.Transaction); id != nil {
810		_, err = s.getTransactionByID(id)
811		if err != nil {
812			return nil, err
813		}
814	}
815	s.mu.Lock()
816	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
817	s.mu.Unlock()
818	resp := &spannerpb.ExecuteBatchDmlResponse{}
819	resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements))
820	for idx, batchStatement := range req.Statements {
821		statementResult, err := s.getStatementResult(batchStatement.Sql)
822		if err != nil {
823			return nil, err
824		}
825		switch statementResult.Type {
826		case StatementResultError:
827			resp.Status = &status.Status{Code: int32(codes.Unknown)}
828		case StatementResultResultSet:
829			return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql))
830		case StatementResultUpdateCount:
831			resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
832			resp.Status = &status.Status{Code: int32(codes.OK)}
833		}
834	}
835	return resp, nil
836}
837
838func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) {
839	s.mu.Lock()
840	if s.stopped {
841		s.mu.Unlock()
842		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
843	}
844	s.receivedRequests <- req
845	s.mu.Unlock()
846	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
847}
848
849func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
850	s.mu.Lock()
851	if s.stopped {
852		s.mu.Unlock()
853		return gstatus.Error(codes.Unavailable, "server has been stopped")
854	}
855	s.receivedRequests <- req
856	s.mu.Unlock()
857	return gstatus.Error(codes.Unimplemented, "Method not yet implemented")
858}
859
860func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
861	if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
862		return nil, err
863	}
864	if req.Session == "" {
865		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
866	}
867	session, err := s.findSession(req.Session)
868	if err != nil {
869		return nil, err
870	}
871	s.updateSessionLastUseTime(session.Name)
872	tx := s.beginTransaction(session, req.Options)
873	return tx, nil
874}
875
876func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) {
877	if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil {
878		return nil, err
879	}
880	if req.Session == "" {
881		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
882	}
883	session, err := s.findSession(req.Session)
884	if err != nil {
885		return nil, err
886	}
887	s.updateSessionLastUseTime(session.Name)
888	var tx *spannerpb.Transaction
889	if req.GetSingleUseTransaction() != nil {
890		tx = s.beginTransaction(session, req.GetSingleUseTransaction())
891	} else if req.GetTransactionId() != nil {
892		tx, err = s.getTransactionByID(req.GetTransactionId())
893		if err != nil {
894			return nil, err
895		}
896	} else {
897		return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
898	}
899	s.removeTransaction(tx)
900	return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil
901}
902
903func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
904	s.mu.Lock()
905	if s.stopped {
906		s.mu.Unlock()
907		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
908	}
909	s.receivedRequests <- req
910	s.mu.Unlock()
911	if req.Session == "" {
912		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
913	}
914	session, err := s.findSession(req.Session)
915	if err != nil {
916		return nil, err
917	}
918	s.updateSessionLastUseTime(session.Name)
919	tx, err := s.getTransactionByID(req.TransactionId)
920	if err != nil {
921		return nil, err
922	}
923	s.removeTransaction(tx)
924	return &emptypb.Empty{}, nil
925}
926
927func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
928	s.mu.Lock()
929	if s.stopped {
930		s.mu.Unlock()
931		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
932	}
933	s.receivedRequests <- req
934	s.mu.Unlock()
935	if req.Session == "" {
936		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
937	}
938	session, err := s.findSession(req.Session)
939	if err != nil {
940		return nil, err
941	}
942	var id []byte
943	var tx *spannerpb.Transaction
944	s.updateSessionLastUseTime(session.Name)
945	if id = s.getTransactionID(session, req.Transaction); id != nil {
946		tx, err = s.getTransactionByID(id)
947		if err != nil {
948			return nil, err
949		}
950	}
951	var partitions []*spannerpb.Partition
952	for i := int64(0); i < req.PartitionOptions.MaxPartitions; i++ {
953		token := make([]byte, 10)
954		_, err := rand.Read(token)
955		if err != nil {
956			return nil, gstatus.Error(codes.Internal, "failed to generate random partition token")
957		}
958		partitions = append(partitions, &spannerpb.Partition{PartitionToken: token})
959	}
960	return &spannerpb.PartitionResponse{
961		Partitions:  partitions,
962		Transaction: tx,
963	}, nil
964}
965
966func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) {
967	s.mu.Lock()
968	if s.stopped {
969		s.mu.Unlock()
970		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
971	}
972	s.receivedRequests <- req
973	s.mu.Unlock()
974	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
975}
976