1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package testutil
16
17import (
18	"bytes"
19	"context"
20	"fmt"
21	"math/rand"
22	"sort"
23	"strings"
24	"sync"
25	"time"
26
27	"github.com/golang/protobuf/ptypes"
28	emptypb "github.com/golang/protobuf/ptypes/empty"
29	structpb "github.com/golang/protobuf/ptypes/struct"
30	"github.com/golang/protobuf/ptypes/timestamp"
31	"google.golang.org/genproto/googleapis/rpc/errdetails"
32	"google.golang.org/genproto/googleapis/rpc/status"
33	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
34	"google.golang.org/grpc/codes"
35	gstatus "google.golang.org/grpc/status"
36)
37
// StatementResultType tells the mock server which kind of mocked result a
// SQL statement produces: an error, a result set or an update count.
type StatementResultType int

const (
	// StatementResultError marks a statement whose execution fails with the
	// error stored in StatementResult.Err.
	StatementResultError StatementResultType = iota
	// StatementResultResultSet marks a statement that yields the rows stored
	// in StatementResult.ResultSet.
	StatementResultResultSet
	// StatementResultUpdateCount marks a statement that yields the count
	// stored in StatementResult.UpdateCount.
	StatementResultUpdateCount
)

// MaxRowsPerPartialResultSet caps how many rows go into a single
// PartialResultSet. It is deliberately tiny so that most streamed queries
// are split across more than one PartialResultSet.
const MaxRowsPerPartialResultSet = 1
56
// The method names that can be used to register execution times and errors.
// They are the keys passed to PutExecutionTime; each RPC handler that
// supports simulated execution times looks up its own key.
const (
	// MethodBeginTransaction selects the BeginTransaction handler.
	MethodBeginTransaction    string = "BEGIN_TRANSACTION"
	// MethodCommitTransaction selects the Commit handler.
	MethodCommitTransaction   string = "COMMIT_TRANSACTION"
	// MethodBatchCreateSession selects the BatchCreateSessions handler.
	MethodBatchCreateSession  string = "BATCH_CREATE_SESSION"
	// MethodCreateSession selects the CreateSession handler.
	MethodCreateSession       string = "CREATE_SESSION"
	// MethodDeleteSession selects the DeleteSession handler.
	MethodDeleteSession       string = "DELETE_SESSION"
	// MethodGetSession selects the GetSession handler.
	MethodGetSession          string = "GET_SESSION"
	// MethodExecuteSql selects the ExecuteSql handler.
	MethodExecuteSql          string = "EXECUTE_SQL"
	// MethodExecuteStreamingSql selects the ExecuteStreamingSql handler.
	MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL"
	// MethodExecuteBatchDml selects the ExecuteBatchDml handler.
	MethodExecuteBatchDml     string = "EXECUTE_BATCH_DML"
)
69
// StatementResult represents a mocked result on the test server. The result is
// either of: a ResultSet, an update count or an error. Type decides which of
// the remaining fields the server uses.
type StatementResult struct {
	// Type selects which of the fields below is returned to the client.
	Type        StatementResultType
	// Err is returned by the server when Type is StatementResultError.
	Err         error
	// ResultSet is returned when Type is StatementResultResultSet. Streaming
	// queries receive it split into PartialResultSets.
	ResultSet   *spannerpb.ResultSet
	// UpdateCount is reported as the row count statistic when Type is
	// StatementResultUpdateCount.
	UpdateCount int64
}

// PartialResultSetExecutionTime represents execution times and errors that
// should be used when a PartialResult at the specified resume token is to
// be returned. Values are queued per SQL string with AddPartialResultSetError,
// and each streaming execution of that SQL string consumes the first queued
// value.
type PartialResultSetExecutionTime struct {
	// ResumeToken identifies the PartialResultSet to act on: the part whose
	// own resume token equals this value.
	ResumeToken   []byte
	// ExecutionTime is how long the server waits before sending (or failing
	// instead of sending) that part.
	ExecutionTime time.Duration
	// Err, if non-nil, is returned instead of sending that part, which ends
	// the stream.
	Err           error
}
87
88// Converts a ResultSet to a PartialResultSet. This method is used to convert
89// a mocked result to a PartialResultSet when one of the streaming methods are
90// called.
91func (s *StatementResult) toPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) {
92	var startIndex uint64
93	if len(resumeToken) > 0 {
94		if startIndex, err = DecodeResumeToken(resumeToken); err != nil {
95			return nil, err
96		}
97	}
98
99	totalRows := uint64(len(s.ResultSet.Rows))
100	for {
101		rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet))
102		rows := s.ResultSet.Rows[startIndex : startIndex+rowCount]
103		values := make([]*structpb.Value,
104			len(rows)*len(s.ResultSet.Metadata.RowType.Fields))
105		var idx int
106		for _, row := range rows {
107			for colIdx := range s.ResultSet.Metadata.RowType.Fields {
108				values[idx] = row.Values[colIdx]
109				idx++
110			}
111		}
112		result = append(result, &spannerpb.PartialResultSet{
113			Metadata:    s.ResultSet.Metadata,
114			Values:      values,
115			ResumeToken: EncodeResumeToken(startIndex + rowCount),
116		})
117		startIndex += rowCount
118		if startIndex == totalRows {
119			break
120		}
121	}
122	return result, nil
123}
124
// min returns the smaller of the unsigned values x and y.
func min(x, y uint64) uint64 {
	smallest := x
	if y < smallest {
		smallest = y
	}
	return smallest
}
131
132func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet {
133	return &spannerpb.PartialResultSet{
134		Stats: s.convertUpdateCountToResultSet(exact).Stats,
135	}
136}
137
138// Converts an update count to a ResultSet, as DML statements also return the
139// update count as the statistics of a ResultSet.
140func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet {
141	if exact {
142		return &spannerpb.ResultSet{
143			Stats: &spannerpb.ResultSetStats{
144				RowCount: &spannerpb.ResultSetStats_RowCountExact{
145					RowCountExact: s.UpdateCount,
146				},
147			},
148		}
149	}
150	return &spannerpb.ResultSet{
151		Stats: &spannerpb.ResultSetStats{
152			RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{
153				RowCountLowerBound: s.UpdateCount,
154			},
155		},
156	}
157}
158
// SimulatedExecutionTime represents the time the execution of a method
// should take, and any errors that should be returned by the method. It is
// registered per method name (one of the Method* constants) with
// PutExecutionTime.
type SimulatedExecutionTime struct {
	// MinimumExecutionTime is the fixed delay applied to every call.
	MinimumExecutionTime time.Duration
	// RandomExecutionTime, if positive, adds a random delay in
	// [0, RandomExecutionTime) on top of MinimumExecutionTime.
	RandomExecutionTime  time.Duration
	// Errors are returned after the delay, the first one on the next call.
	// Unless KeepError is set, each returned error is removed from the list,
	// so successive calls get successive errors until the list is empty.
	Errors               []error
	// Keep error after execution. The error will continue to be returned until
	// it is cleared.
	KeepError bool
}
169
// InMemSpannerServer contains the SpannerServer interface plus a couple
// of specific methods for adding mocked results and resetting the server.
type InMemSpannerServer interface {
	spannerpb.SpannerServer

	// Stops this server. Requests received afterwards fail with
	// codes.Unavailable, and the channel returned by ReceivedRequests is
	// closed.
	Stop()

	// Resets the in-mem server to its default state, deleting all sessions and
	// transactions that have been created on the server. Mocked results are
	// not deleted. The channel returned by ReceivedRequests is closed and
	// replaced by a new one.
	Reset()

	// Sets an error that will be returned by the next server call. The server
	// call will also automatically clear the error. Only methods that support
	// simulated execution times (see the Method* constants) return this error.
	SetError(err error)

	// Puts a mocked result on the server for a specific sql statement. The
	// server does not parse the SQL string in any way, it is merely used as
	// a key to the mocked result. The result will be used for all methods that
	// expect a SQL statement, including (batch) DML methods.
	PutStatementResult(sql string, result *StatementResult) error

	// Puts a mocked result on the server for a specific partition token. The
	// result will only be used for query requests that specify a partition
	// token.
	PutPartitionResult(partitionToken []byte, result *StatementResult) error

	// Adds a PartialResultSetExecutionTime to the server that should be returned
	// for the specified SQL string. Each streaming execution of the SQL string
	// consumes the oldest one that was added.
	AddPartialResultSetError(sql string, err PartialResultSetExecutionTime)

	// Removes a mocked result on the server for a specific sql statement.
	RemoveStatementResult(sql string)

	// Aborts the specified transaction. Later requests that use the
	// transaction fail with codes.Aborted. This method can be used to test
	// transaction retry logic.
	AbortTransaction(id []byte)

	// Puts a simulated execution time for one of the Spanner methods. method
	// is one of the Method* constants.
	PutExecutionTime(method string, executionTime SimulatedExecutionTime)
	// Freeze stalls all requests to methods that support simulated execution
	// times (see the Method* constants) until Unfreeze is called.
	Freeze()
	// Unfreeze restores processing requests.
	Unfreeze()

	// TotalSessionsCreated returns the number of sessions created since the
	// server was constructed. Reset does not clear this counter.
	TotalSessionsCreated() uint
	// TotalSessionsDeleted returns the number of sessions deleted since the
	// server was constructed. Reset does not clear this counter.
	TotalSessionsDeleted() uint
	// SetMaxSessionsReturnedByServerPerBatchRequest caps the number of
	// sessions a single BatchCreateSessions call creates. Reset restores the
	// default of 100.
	SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32)
	// SetMaxSessionsReturnedByServerInTotal caps the number of sessions that
	// may exist on the server at the same time; a value <= 0 means no limit.
	// Session creation fails with codes.ResourceExhausted once it is reached.
	SetMaxSessionsReturnedByServerInTotal(sessionCount int32)

	// ReceivedRequests returns the buffered channel on which the server
	// publishes every request it receives.
	ReceivedRequests() chan interface{}
	// DumpSessions returns the names of all current sessions as a set.
	DumpSessions() map[string]bool
	// ClearPings clears the ping history.
	ClearPings()
	// DumpPings returns a copy of the ping history: the session names of the
	// ExecuteSql requests with the statement "SELECT 1", in arrival order.
	DumpPings() []string
}
226
// inMemSpannerServer is the implementation of InMemSpannerServer. The
// server methods lock mu around their accesses to the state below.
type inMemSpannerServer struct {
	// Embed for forward compatibility.
	// Tests will keep working if more methods are added
	// in the future.
	spannerpb.SpannerServer

	mu sync.Mutex
	// Set to true when this server has been stopped. This is the end state of
	// a server, a stopped server cannot be restarted.
	stopped bool
	// If set, the next call that supports simulated execution times returns
	// this error, which is then cleared.
	err error
	// The mock server creates session IDs using this counter.
	sessionCounter uint64
	// The sessions that have been created on this mock server.
	sessions map[string]*spannerpb.Session
	// Last use times per session.
	sessionLastUseTime map[string]time.Time
	// The mock server creates transaction IDs per session using these
	// counters.
	transactionCounters map[string]*uint64
	// The transactions that have been created on this mock server.
	transactions map[string]*spannerpb.Transaction
	// The transactions that have been (manually) aborted on the server.
	abortedTransactions map[string]bool
	// The transactions that are marked as PartitionedDMLTransaction
	partitionedDmlTransactions map[string]bool
	// The mocked results for this server, keyed by SQL string and by
	// partition token respectively.
	statementResults map[string]*StatementResult
	partitionResults map[string]*StatementResult
	// The simulated execution times per method.
	executionTimes map[string]*SimulatedExecutionTime
	// The simulated errors for partial result sets, queued per SQL string.
	partialResultSetErrors map[string][]*PartialResultSetExecutionTime

	totalSessionsCreated uint
	totalSessionsDeleted uint
	// The maximum number of sessions that will be created per batch request.
	maxSessionsReturnedByServerPerBatchRequest int32
	// The maximum number of sessions that may exist at once (<= 0: no limit).
	maxSessionsReturnedByServerInTotal         int32
	// Every received request is published on this channel.
	receivedRequests                           chan interface{}
	// Session ping history.
	pings []string

	// Server will stall on any requests. Requests wait for this channel to be
	// closed; a closed channel means the server is not frozen.
	freezed chan struct{}
}
274
275// NewInMemSpannerServer creates a new in-mem test server.
276func NewInMemSpannerServer() InMemSpannerServer {
277	res := &inMemSpannerServer{}
278	res.initDefaults()
279	res.statementResults = make(map[string]*StatementResult)
280	res.partitionResults = make(map[string]*StatementResult)
281	res.executionTimes = make(map[string]*SimulatedExecutionTime)
282	res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime)
283	res.receivedRequests = make(chan interface{}, 1000000)
284	// Produce a closed channel, so the default action of ready is to not block.
285	res.Freeze()
286	res.Unfreeze()
287	return res
288}
289
290func (s *inMemSpannerServer) Stop() {
291	s.mu.Lock()
292	defer s.mu.Unlock()
293	s.stopped = true
294	close(s.receivedRequests)
295}
296
297// Resets the test server to its initial state, deleting all sessions and
298// transactions that have been created on the server. This method will not
299// remove mocked results.
300func (s *inMemSpannerServer) Reset() {
301	s.mu.Lock()
302	defer s.mu.Unlock()
303	close(s.receivedRequests)
304	s.receivedRequests = make(chan interface{}, 1000000)
305	s.initDefaults()
306}
307
308func (s *inMemSpannerServer) SetError(err error) {
309	s.mu.Lock()
310	defer s.mu.Unlock()
311	s.err = err
312}
313
314// Registers a mocked result for a SQL statement on the server.
315func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error {
316	s.mu.Lock()
317	defer s.mu.Unlock()
318	s.statementResults[sql] = result
319	return nil
320}
321
322func (s *inMemSpannerServer) RemoveStatementResult(sql string) {
323	s.mu.Lock()
324	defer s.mu.Unlock()
325	delete(s.statementResults, sql)
326}
327
328// Registers a mocked result for a partition token on the server.
329func (s *inMemSpannerServer) PutPartitionResult(partitionToken []byte, result *StatementResult) error {
330	tokenString := string(partitionToken)
331	s.mu.Lock()
332	defer s.mu.Unlock()
333	s.partitionResults[tokenString] = result
334	return nil
335}
336
337func (s *inMemSpannerServer) AbortTransaction(id []byte) {
338	s.mu.Lock()
339	defer s.mu.Unlock()
340	s.abortedTransactions[string(id)] = true
341}
342
343func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) {
344	s.mu.Lock()
345	defer s.mu.Unlock()
346	s.executionTimes[method] = &executionTime
347}
348
349func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) {
350	s.mu.Lock()
351	defer s.mu.Unlock()
352	s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError)
353}
354
355// Freeze stalls all requests.
356func (s *inMemSpannerServer) Freeze() {
357	s.mu.Lock()
358	defer s.mu.Unlock()
359	s.freezed = make(chan struct{})
360}
361
362// Unfreeze restores processing requests.
363func (s *inMemSpannerServer) Unfreeze() {
364	s.mu.Lock()
365	defer s.mu.Unlock()
366	close(s.freezed)
367}
368
369// ready checks conditions before executing requests
370func (s *inMemSpannerServer) ready() {
371	s.mu.Lock()
372	freezed := s.freezed
373	s.mu.Unlock()
374	// check if server should be freezed
375	<-freezed
376}
377
378func (s *inMemSpannerServer) TotalSessionsCreated() uint {
379	s.mu.Lock()
380	defer s.mu.Unlock()
381	return s.totalSessionsCreated
382}
383
384func (s *inMemSpannerServer) TotalSessionsDeleted() uint {
385	s.mu.Lock()
386	defer s.mu.Unlock()
387	return s.totalSessionsDeleted
388}
389
390func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) {
391	s.mu.Lock()
392	defer s.mu.Unlock()
393	s.maxSessionsReturnedByServerPerBatchRequest = sessionCount
394}
395
396func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) {
397	s.mu.Lock()
398	defer s.mu.Unlock()
399	s.maxSessionsReturnedByServerInTotal = sessionCount
400}
401
402func (s *inMemSpannerServer) ReceivedRequests() chan interface{} {
403	return s.receivedRequests
404}
405
406// ClearPings clears the ping history from the server.
407func (s *inMemSpannerServer) ClearPings() {
408	s.mu.Lock()
409	defer s.mu.Unlock()
410	s.pings = nil
411}
412
413// DumpPings dumps the ping history.
414func (s *inMemSpannerServer) DumpPings() []string {
415	s.mu.Lock()
416	defer s.mu.Unlock()
417	return append([]string(nil), s.pings...)
418}
419
420// DumpSessions dumps the internal session table.
421func (s *inMemSpannerServer) DumpSessions() map[string]bool {
422	s.mu.Lock()
423	defer s.mu.Unlock()
424	st := map[string]bool{}
425	for s := range s.sessions {
426		st[s] = true
427	}
428	return st
429}
430
431func (s *inMemSpannerServer) initDefaults() {
432	s.sessionCounter = 0
433	s.maxSessionsReturnedByServerPerBatchRequest = 100
434	s.sessions = make(map[string]*spannerpb.Session)
435	s.sessionLastUseTime = make(map[string]time.Time)
436	s.transactions = make(map[string]*spannerpb.Transaction)
437	s.abortedTransactions = make(map[string]bool)
438	s.partitionedDmlTransactions = make(map[string]bool)
439	s.transactionCounters = make(map[string]*uint64)
440}
441
442func (s *inMemSpannerServer) generateSessionNameLocked(database string) string {
443	s.sessionCounter++
444	return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter)
445}
446
447func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) {
448	s.mu.Lock()
449	defer s.mu.Unlock()
450	session := s.sessions[name]
451	if session == nil {
452		return nil, newSessionNotFoundError(name)
453	}
454	return session, nil
455}
456
// sessionResourceType is the type name of Spanner sessions. It is reported as
// the ResourceType of the ResourceInfo detail attached to the NotFound errors
// created by newSessionNotFoundError.
const sessionResourceType = "type.googleapis.com/google.spanner.v1.Session"
459
460func newSessionNotFoundError(name string) error {
461	s := gstatus.Newf(codes.NotFound, "Session not found: Session with id %s not found", name)
462	s, _ = s.WithDetails(&errdetails.ResourceInfo{ResourceType: sessionResourceType, ResourceName: name})
463	return s.Err()
464}
465
466func (s *inMemSpannerServer) updateSessionLastUseTime(session string) {
467	s.mu.Lock()
468	defer s.mu.Unlock()
469	s.sessionLastUseTime[session] = time.Now()
470}
471
472func getCurrentTimestamp() *timestamp.Timestamp {
473	t := time.Now()
474	return &timestamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())}
475}
476
477// Gets the transaction id from the transaction selector. If the selector
478// specifies that a new transaction should be started, this method will start
479// a new transaction and return the id of that transaction.
480func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte {
481	var res []byte
482	if txSelector.GetBegin() != nil {
483		// Start a new transaction.
484		res = s.beginTransaction(session, txSelector.GetBegin()).Id
485	} else if txSelector.GetId() != nil {
486		res = txSelector.GetId()
487	}
488	return res
489}
490
491func (s *inMemSpannerServer) generateTransactionName(session string) string {
492	s.mu.Lock()
493	defer s.mu.Unlock()
494	counter, ok := s.transactionCounters[session]
495	if !ok {
496		counter = new(uint64)
497		s.transactionCounters[session] = counter
498	}
499	*counter++
500	return fmt.Sprintf("%s/transactions/%d", session, *counter)
501}
502
503func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction {
504	id := s.generateTransactionName(session.Name)
505	res := &spannerpb.Transaction{
506		Id:            []byte(id),
507		ReadTimestamp: getCurrentTimestamp(),
508	}
509	s.mu.Lock()
510	s.transactions[id] = res
511	s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil
512	s.mu.Unlock()
513	return res
514}
515
516func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) {
517	s.mu.Lock()
518	defer s.mu.Unlock()
519	tx, ok := s.transactions[string(id)]
520	if !ok {
521		return nil, gstatus.Error(codes.NotFound, "Transaction not found")
522	}
523	aborted, ok := s.abortedTransactions[string(id)]
524	if ok && aborted {
525		return nil, newAbortedErrorWithMinimalRetryDelay()
526	}
527	return tx, nil
528}
529
530func newAbortedErrorWithMinimalRetryDelay() error {
531	st := gstatus.New(codes.Aborted, "Transaction has been aborted")
532	retry := &errdetails.RetryInfo{
533		RetryDelay: ptypes.DurationProto(time.Nanosecond),
534	}
535	st, _ = st.WithDetails(retry)
536	return st.Err()
537}
538
539func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
540	s.mu.Lock()
541	defer s.mu.Unlock()
542	delete(s.transactions, string(tx.Id))
543	delete(s.partitionedDmlTransactions, string(tx.Id))
544}
545
546func (s *inMemSpannerServer) getPartitionResult(partitionToken []byte) (*StatementResult, error) {
547	tokenString := string(partitionToken)
548	s.mu.Lock()
549	defer s.mu.Unlock()
550	result, ok := s.partitionResults[tokenString]
551	if !ok {
552		return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for partition token %v", tokenString))
553	}
554	return result, nil
555}
556
557func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) {
558	s.mu.Lock()
559	defer s.mu.Unlock()
560	result, ok := s.statementResults[sql]
561	if !ok {
562		return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql))
563	}
564	return result, nil
565}
566
// simulateExecutionTime runs the common prologue of most RPC handlers:
//  1. If the server has been stopped, it returns an Unavailable error.
//  2. It publishes req on receivedRequests. The send happens while s.mu is
//     held, so if the channel buffer is full no other call can acquire s.mu
//     until a reader drains it.
//  3. It waits until the server is not frozen (see Freeze/Unfreeze).
//  4. If an error was registered with SetError, it returns and clears it.
//  5. If a SimulatedExecutionTime is registered for method, it sleeps for
//     MinimumExecutionTime plus a random duration in [0, RandomExecutionTime)
//     and then returns the first of its Errors, if any. That error is removed
//     from the list unless KeepError is set.
//
// It returns nil when the request should be processed normally.
func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error {
	s.mu.Lock()
	if s.stopped {
		s.mu.Unlock()
		return gstatus.Error(codes.Unavailable, "server has been stopped")
	}
	s.receivedRequests <- req
	s.mu.Unlock()
	// Blocks while the server is frozen; must not hold s.mu here.
	s.ready()
	s.mu.Lock()
	if s.err != nil {
		// The error set by SetError is returned only once.
		err := s.err
		s.err = nil
		s.mu.Unlock()
		return err
	}
	executionTime, ok := s.executionTimes[method]
	s.mu.Unlock()
	if ok {
		var randTime int64
		if executionTime.RandomExecutionTime > 0 {
			randTime = rand.Int63n(int64(executionTime.RandomExecutionTime))
		}
		totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime)
		// Sleep without holding s.mu so that other calls can proceed.
		<-time.After(totalExecutionTime)
		s.mu.Lock()
		if executionTime.Errors != nil && len(executionTime.Errors) > 0 {
			err := executionTime.Errors[0]
			if !executionTime.KeepError {
				// Consume the error so the next call gets the next one.
				executionTime.Errors = executionTime.Errors[1:]
			}
			s.mu.Unlock()
			return err
		}
		s.mu.Unlock()
	}
	return nil
}
605
606func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
607	if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
608		return nil, err
609	}
610	if req.Database == "" {
611		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
612	}
613	s.mu.Lock()
614	defer s.mu.Unlock()
615	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal {
616		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
617	}
618	sessionName := s.generateSessionNameLocked(req.Database)
619	ts := getCurrentTimestamp()
620	session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
621	s.totalSessionsCreated++
622	s.sessions[sessionName] = session
623	return session, nil
624}
625
626func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
627	if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
628		return nil, err
629	}
630	if req.Database == "" {
631		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
632	}
633	if req.SessionCount <= 0 {
634		return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0")
635	}
636	sessionsToCreate := req.SessionCount
637	s.mu.Lock()
638	defer s.mu.Unlock()
639	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal {
640		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
641	}
642	if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest {
643		sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest
644	}
645	if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal {
646		sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions))
647	}
648	sessions := make([]*spannerpb.Session, sessionsToCreate)
649	for i := int32(0); i < sessionsToCreate; i++ {
650		sessionName := s.generateSessionNameLocked(req.Database)
651		ts := getCurrentTimestamp()
652		sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
653		s.totalSessionsCreated++
654		s.sessions[sessionName] = sessions[i]
655	}
656	return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil
657}
658
659func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
660	if err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
661		return nil, err
662	}
663	if req.Name == "" {
664		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
665	}
666	session, err := s.findSession(req.Name)
667	if err != nil {
668		return nil, err
669	}
670	return session, nil
671}
672
673func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) {
674	s.mu.Lock()
675	if s.stopped {
676		s.mu.Unlock()
677		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
678	}
679	s.receivedRequests <- req
680	s.mu.Unlock()
681	if req.Database == "" {
682		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
683	}
684	expectedSessionName := req.Database + "/sessions/"
685	var sessions []*spannerpb.Session
686	s.mu.Lock()
687	for _, session := range s.sessions {
688		if strings.Index(session.Name, expectedSessionName) == 0 {
689			sessions = append(sessions, session)
690		}
691	}
692	s.mu.Unlock()
693	sort.Slice(sessions[:], func(i, j int) bool {
694		return sessions[i].Name < sessions[j].Name
695	})
696	res := &spannerpb.ListSessionsResponse{Sessions: sessions}
697	return res, nil
698}
699
700func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
701	if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
702		return nil, err
703	}
704	if req.Name == "" {
705		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
706	}
707	if _, err := s.findSession(req.Name); err != nil {
708		return nil, err
709	}
710	s.mu.Lock()
711	defer s.mu.Unlock()
712	s.totalSessionsDeleted++
713	delete(s.sessions, req.Name)
714	return &emptypb.Empty{}, nil
715}
716
717func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
718	if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
719		return nil, err
720	}
721	if req.Sql == "SELECT 1" {
722		s.mu.Lock()
723		s.pings = append(s.pings, req.Session)
724		s.mu.Unlock()
725	}
726	if req.Session == "" {
727		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
728	}
729	session, err := s.findSession(req.Session)
730	if err != nil {
731		return nil, err
732	}
733	var id []byte
734	s.updateSessionLastUseTime(session.Name)
735	if id = s.getTransactionID(session, req.Transaction); id != nil {
736		_, err = s.getTransactionByID(id)
737		if err != nil {
738			return nil, err
739		}
740	}
741	var statementResult *StatementResult
742	if req.PartitionToken != nil {
743		statementResult, err = s.getPartitionResult(req.PartitionToken)
744	} else {
745		statementResult, err = s.getStatementResult(req.Sql)
746	}
747	if err != nil {
748		return nil, err
749	}
750	s.mu.Lock()
751	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
752	s.mu.Unlock()
753	switch statementResult.Type {
754	case StatementResultError:
755		return nil, statementResult.Err
756	case StatementResultResultSet:
757		return statementResult.ResultSet, nil
758	case StatementResultUpdateCount:
759		return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil
760	}
761	return nil, gstatus.Error(codes.Internal, "Unknown result type")
762}
763
// ExecuteStreamingSql streams the mocked result registered for the request's
// partition token (if set) or SQL string. Result sets are sent as a sequence
// of PartialResultSets, starting at req.ResumeToken when given. Update counts
// are sent as a single statistics-only PartialResultSet (a lower bound for
// partitioned DML, exact otherwise).
//
// Each call for a result set consumes the oldest PartialResultSetExecutionTime
// queued for req.Sql with AddPartialResultSetError. When the part whose resume
// token matches it is about to be sent, the server first waits for its
// ExecutionTime and then, if its Err is set, ends the stream with that error.
func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
	if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
		return err
	}
	if req.Session == "" {
		return gstatus.Error(codes.InvalidArgument, "Missing session name")
	}
	session, err := s.findSession(req.Session)
	if err != nil {
		return err
	}
	s.updateSessionLastUseTime(session.Name)
	var id []byte
	// Starts a new transaction if the selector asks for one, and rejects
	// unknown or aborted transactions.
	if id = s.getTransactionID(session, req.Transaction); id != nil {
		_, err = s.getTransactionByID(id)
		if err != nil {
			return err
		}
	}
	var statementResult *StatementResult
	if req.PartitionToken != nil {
		statementResult, err = s.getPartitionResult(req.PartitionToken)
	} else {
		statementResult, err = s.getStatementResult(req.Sql)
	}
	if err != nil {
		return err
	}
	s.mu.Lock()
	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
	s.mu.Unlock()
	switch statementResult.Type {
	case StatementResultError:
		return statementResult.Err
	case StatementResultResultSet:
		parts, err := statementResult.toPartialResultSets(req.ResumeToken)
		if err != nil {
			return err
		}
		// Pop the oldest queued partial result set error for this SQL string.
		var nextPartialResultSetError *PartialResultSetExecutionTime
		s.mu.Lock()
		pErrors := s.partialResultSetErrors[req.Sql]
		if len(pErrors) > 0 {
			nextPartialResultSetError = pErrors[0]
			s.partialResultSetErrors[req.Sql] = pErrors[1:]
		}
		s.mu.Unlock()
		for _, part := range parts {
			if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) {
				if nextPartialResultSetError.ExecutionTime > 0 {
					<-time.After(nextPartialResultSetError.ExecutionTime)
				}
				// Fail before this part is sent; earlier parts were delivered.
				if nextPartialResultSetError.Err != nil {
					return nextPartialResultSetError.Err
				}
			}
			if err := stream.Send(part); err != nil {
				return err
			}
		}
		return nil
	case StatementResultUpdateCount:
		part := statementResult.updateCountToPartialResultSet(!isPartitionedDml)
		if err := stream.Send(part); err != nil {
			return err
		}
		return nil
	}
	return gstatus.Error(codes.Internal, "Unknown result type")
}
834
835func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
836	if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
837		return nil, err
838	}
839	if req.Session == "" {
840		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
841	}
842	session, err := s.findSession(req.Session)
843	if err != nil {
844		return nil, err
845	}
846	s.updateSessionLastUseTime(session.Name)
847	var id []byte
848	if id = s.getTransactionID(session, req.Transaction); id != nil {
849		_, err = s.getTransactionByID(id)
850		if err != nil {
851			return nil, err
852		}
853	}
854	s.mu.Lock()
855	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
856	s.mu.Unlock()
857	resp := &spannerpb.ExecuteBatchDmlResponse{}
858	resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements))
859	resp.Status = &status.Status{Code: int32(codes.OK)}
860	for idx, batchStatement := range req.Statements {
861		statementResult, err := s.getStatementResult(batchStatement.Sql)
862		if err != nil {
863			return nil, err
864		}
865		switch statementResult.Type {
866		case StatementResultError:
867			resp.Status = &status.Status{Code: int32(gstatus.Code(statementResult.Err)), Message: statementResult.Err.Error()}
868			resp.ResultSets = resp.ResultSets[:idx]
869			return resp, nil
870		case StatementResultResultSet:
871			return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql))
872		case StatementResultUpdateCount:
873			resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
874		}
875	}
876	return resp, nil
877}
878
879func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) {
880	s.mu.Lock()
881	if s.stopped {
882		s.mu.Unlock()
883		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
884	}
885	s.receivedRequests <- req
886	s.mu.Unlock()
887	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
888}
889
890func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
891	s.mu.Lock()
892	if s.stopped {
893		s.mu.Unlock()
894		return gstatus.Error(codes.Unavailable, "server has been stopped")
895	}
896	s.receivedRequests <- req
897	s.mu.Unlock()
898	return gstatus.Error(codes.Unimplemented, "Method not yet implemented")
899}
900
901func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
902	if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
903		return nil, err
904	}
905	if req.Session == "" {
906		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
907	}
908	session, err := s.findSession(req.Session)
909	if err != nil {
910		return nil, err
911	}
912	s.updateSessionLastUseTime(session.Name)
913	tx := s.beginTransaction(session, req.Options)
914	return tx, nil
915}
916
917func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) {
918	if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil {
919		return nil, err
920	}
921	if req.Session == "" {
922		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
923	}
924	session, err := s.findSession(req.Session)
925	if err != nil {
926		return nil, err
927	}
928	s.updateSessionLastUseTime(session.Name)
929	var tx *spannerpb.Transaction
930	if req.GetSingleUseTransaction() != nil {
931		tx = s.beginTransaction(session, req.GetSingleUseTransaction())
932	} else if req.GetTransactionId() != nil {
933		tx, err = s.getTransactionByID(req.GetTransactionId())
934		if err != nil {
935			return nil, err
936		}
937	} else {
938		return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
939	}
940	s.removeTransaction(tx)
941	return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil
942}
943
944func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
945	s.mu.Lock()
946	if s.stopped {
947		s.mu.Unlock()
948		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
949	}
950	s.receivedRequests <- req
951	s.mu.Unlock()
952	if req.Session == "" {
953		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
954	}
955	session, err := s.findSession(req.Session)
956	if err != nil {
957		return nil, err
958	}
959	s.updateSessionLastUseTime(session.Name)
960	tx, err := s.getTransactionByID(req.TransactionId)
961	if err != nil {
962		return nil, err
963	}
964	s.removeTransaction(tx)
965	return &emptypb.Empty{}, nil
966}
967
968func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
969	s.mu.Lock()
970	if s.stopped {
971		s.mu.Unlock()
972		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
973	}
974	s.receivedRequests <- req
975	s.mu.Unlock()
976	if req.Session == "" {
977		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
978	}
979	session, err := s.findSession(req.Session)
980	if err != nil {
981		return nil, err
982	}
983	var id []byte
984	var tx *spannerpb.Transaction
985	s.updateSessionLastUseTime(session.Name)
986	if id = s.getTransactionID(session, req.Transaction); id != nil {
987		tx, err = s.getTransactionByID(id)
988		if err != nil {
989			return nil, err
990		}
991	}
992	var partitions []*spannerpb.Partition
993	for i := int64(0); i < req.PartitionOptions.MaxPartitions; i++ {
994		token := make([]byte, 10)
995		_, err := rand.Read(token)
996		if err != nil {
997			return nil, gstatus.Error(codes.Internal, "failed to generate random partition token")
998		}
999		partitions = append(partitions, &spannerpb.Partition{PartitionToken: token})
1000	}
1001	return &spannerpb.PartitionResponse{
1002		Partitions:  partitions,
1003		Transaction: tx,
1004	}, nil
1005}
1006
1007func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) {
1008	s.mu.Lock()
1009	if s.stopped {
1010		s.mu.Unlock()
1011		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
1012	}
1013	s.receivedRequests <- req
1014	s.mu.Unlock()
1015	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
1016}
1017