1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package testutil
16
17import (
18	"bytes"
19	"context"
20	"encoding/binary"
21	"fmt"
22	"math/rand"
23	"sort"
24	"strings"
25	"sync"
26	"time"
27
28	"github.com/golang/protobuf/ptypes"
29	emptypb "github.com/golang/protobuf/ptypes/empty"
30	structpb "github.com/golang/protobuf/ptypes/struct"
31	"github.com/golang/protobuf/ptypes/timestamp"
32	"google.golang.org/genproto/googleapis/rpc/errdetails"
33	"google.golang.org/genproto/googleapis/rpc/status"
34	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
35	"google.golang.org/grpc/codes"
36	gstatus "google.golang.org/grpc/status"
37)
38
var (
	// KvMeta is the result set metadata of a mocked key-value table with two
	// STRING columns named "Key" and "Value".
	KvMeta = spannerpb.ResultSetMetadata{
		RowType: &spannerpb.StructType{
			Fields: []*spannerpb.StructType_Field{
				{
					Name: "Key",
					Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING},
				},
				{
					Name: "Value",
					Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING},
				},
			},
		},
	}
)
56
// StatementResultType selects which kind of mocked result a SQL statement
// produces on the test server.
type StatementResultType int

// The kinds of mocked statement results.
const (
	// StatementResultError indicates that the sql statement returns an error.
	StatementResultError StatementResultType = iota
	// StatementResultResultSet indicates that the sql statement returns a
	// result set.
	StatementResultResultSet
	// StatementResultUpdateCount indicates that the sql statement returns an
	// update count.
	StatementResultUpdateCount
)

// MaxRowsPerPartialResultSet is the maximum number of rows returned in each
// PartialResultSet. This number is deliberately set to a low value to ensure
// that most queries return more than one PartialResultSet.
const MaxRowsPerPartialResultSet = 1
75
// The method names that can be used to register execution times and errors
// with PutExecutionTime. The server methods pass these names to
// simulateExecutionTime as lookup keys (e.g. CreateSession uses
// MethodCreateSession and StreamingRead uses MethodStreamingRead). ListSessions
// and Read do not use any of them.
const (
	MethodBeginTransaction    string = "BEGIN_TRANSACTION"
	MethodCommitTransaction   string = "COMMIT_TRANSACTION"
	MethodBatchCreateSession  string = "BATCH_CREATE_SESSION"
	MethodCreateSession       string = "CREATE_SESSION"
	MethodDeleteSession       string = "DELETE_SESSION"
	MethodGetSession          string = "GET_SESSION"
	MethodExecuteSql          string = "EXECUTE_SQL"
	MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL"
	MethodExecuteBatchDml     string = "EXECUTE_BATCH_DML"
	MethodStreamingRead       string = "EXECUTE_STREAMING_READ"
)
89
// StatementResult represents a mocked result on the test server. The result is
// either of: a ResultSet, an update count or an error, selected by Type:
//   - StatementResultError: Err is returned.
//   - StatementResultResultSet: ResultSet is returned; streaming methods send
//     it as PartialResultSets (see ToPartialResultSets).
//   - StatementResultUpdateCount: UpdateCount is returned as the row count
//     statistics of the result.
//
// ResumeTokens is optional. When it is empty, the resume token of each
// streamed part encodes the index of the next row; otherwise ResumeTokens[i]
// is used as the resume token of the part that starts at row i.
type StatementResult struct {
	Type         StatementResultType
	Err          error
	ResultSet    *spannerpb.ResultSet
	UpdateCount  int64
	ResumeTokens [][]byte
}
99
// PartialResultSetExecutionTime represents execution times and errors that
// should be used when a PartialResult at the specified resume token is to
// be returned.
//
// When a streamed part's resume token equals ResumeToken, the server first
// waits ExecutionTime (if positive). If Err is non-nil, it then ends the
// stream with Err instead of sending that part.
type PartialResultSetExecutionTime struct {
	ResumeToken   []byte
	ExecutionTime time.Duration
	Err           error
}
108
109// ToPartialResultSets converts a ResultSet to a PartialResultSet. This method
110// is used to convert a mocked result to a PartialResultSet when one of the
111// streaming methods are called.
112func (s *StatementResult) ToPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) {
113	var startIndex uint64
114	if len(resumeToken) > 0 {
115		if startIndex, err = DecodeResumeToken(resumeToken); err != nil {
116			return nil, err
117		}
118	}
119
120	totalRows := uint64(len(s.ResultSet.Rows))
121	if totalRows > 0 {
122		for {
123			rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet))
124			rows := s.ResultSet.Rows[startIndex : startIndex+rowCount]
125			values := make([]*structpb.Value,
126				len(rows)*len(s.ResultSet.Metadata.RowType.Fields))
127			var idx int
128			for _, row := range rows {
129				for colIdx := range s.ResultSet.Metadata.RowType.Fields {
130					values[idx] = row.Values[colIdx]
131					idx++
132				}
133			}
134			var rt []byte
135			if len(s.ResumeTokens) == 0 {
136				rt = EncodeResumeToken(startIndex + rowCount)
137			} else {
138				rt = s.ResumeTokens[startIndex]
139			}
140			result = append(result, &spannerpb.PartialResultSet{
141				Metadata:    s.ResultSet.Metadata,
142				Values:      values,
143				ResumeToken: rt,
144			})
145
146			startIndex += rowCount
147			if startIndex == totalRows {
148				break
149			}
150		}
151	} else {
152		result = append(result, &spannerpb.PartialResultSet{
153			Metadata: s.ResultSet.Metadata,
154		})
155	}
156	return result, nil
157}
158
// min returns the smaller of two uint64 values.
func min(x, y uint64) uint64 {
	smaller := x
	if y < smaller {
		smaller = y
	}
	return smaller
}
165
166func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet {
167	return &spannerpb.PartialResultSet{
168		Stats: s.convertUpdateCountToResultSet(exact).Stats,
169	}
170}
171
172// Converts an update count to a ResultSet, as DML statements also return the
173// update count as the statistics of a ResultSet.
174func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet {
175	if exact {
176		return &spannerpb.ResultSet{
177			Stats: &spannerpb.ResultSetStats{
178				RowCount: &spannerpb.ResultSetStats_RowCountExact{
179					RowCountExact: s.UpdateCount,
180				},
181			},
182		}
183	}
184	return &spannerpb.ResultSet{
185		Stats: &spannerpb.ResultSetStats{
186			RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{
187				RowCountLowerBound: s.UpdateCount,
188			},
189		},
190	}
191}
192
// SimulatedExecutionTime represents the time the execution of a method
// should take, and any errors that should be returned by the method.
//
// Each call of the method waits MinimumExecutionTime plus, if
// RandomExecutionTime is positive, a random extra duration in
// [0, RandomExecutionTime). If Errors is then non-empty, the call returns
// Errors[0]. Unless KeepError is set, that error is removed from Errors
// afterwards, so consecutive calls return the errors in order.
type SimulatedExecutionTime struct {
	MinimumExecutionTime time.Duration
	RandomExecutionTime  time.Duration
	Errors               []error
	// Keep error after execution. The error will continue to be returned until
	// it is cleared.
	KeepError bool
}
203
// InMemSpannerServer contains the SpannerServer interface plus a couple
// of specific methods for adding mocked results, simulating latency and
// errors, and inspecting or resetting the server state in tests.
type InMemSpannerServer interface {
	spannerpb.SpannerServer

	// Stops this server. A stopped server rejects requests and cannot be
	// restarted.
	Stop()

	// Resets the in-mem server to its default state, deleting all sessions and
	// transactions that have been created on the server. Mocked results are
	// not deleted.
	Reset()

	// Sets an error that will be returned by the next server call. The server
	// call will also automatically clear the error.
	SetError(err error)

	// Puts a mocked result on the server for a specific sql statement. The
	// server does not parse the SQL string in any way, it is merely used as
	// a key to the mocked result. The result will be used for all methods that
	// expect a SQL statement, including (batch) DML methods.
	PutStatementResult(sql string, result *StatementResult) error

	// Puts a mocked result on the server for a specific partition token. The
	// result will only be used for query requests that specify a partition
	// token.
	PutPartitionResult(partitionToken []byte, result *StatementResult) error

	// Adds a PartialResultSetExecutionTime to the server that should be returned
	// for the specified SQL string. Entries for the same SQL string are
	// consumed in the order they were added, one per streaming query.
	AddPartialResultSetError(sql string, err PartialResultSetExecutionTime)

	// Removes a mocked result on the server for a specific sql statement.
	RemoveStatementResult(sql string)

	// Aborts the specified transaction. This method can be used to test
	// transaction retry logic.
	AbortTransaction(id []byte)

	// Puts a simulated execution time for one of the Spanner methods. method
	// should be one of the Method* constants.
	PutExecutionTime(method string, executionTime SimulatedExecutionTime)
	// Freeze stalls all requests.
	Freeze()
	// Unfreeze restores processing requests.
	Unfreeze()

	// TotalSessionsCreated and TotalSessionsDeleted return how many sessions
	// have been created and deleted on this server.
	TotalSessionsCreated() uint
	TotalSessionsDeleted() uint
	// SetMaxSessionsReturnedByServerPerBatchRequest caps the number of
	// sessions returned by a single BatchCreateSessions call.
	SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32)
	// SetMaxSessionsReturnedByServerInTotal caps the number of sessions that
	// may exist at the same time; a value <= 0 means no limit.
	SetMaxSessionsReturnedByServerInTotal(sessionCount int32)

	// ReceivedRequests returns the channel on which every received request is
	// recorded.
	ReceivedRequests() chan interface{}
	// DumpSessions returns the set of names of all existing sessions.
	DumpSessions() map[string]bool
	// ClearPings clears the ping history and DumpPings returns a copy of it.
	// The history holds the session name of every "SELECT 1" run via ExecuteSql.
	ClearPings()
	DumpPings() []string
}
260
// inMemSpannerServer is the implementation of InMemSpannerServer returned by
// NewInMemSpannerServer.
type inMemSpannerServer struct {
	// Embed for forward compatibility.
	// Tests will keep working if more methods are added
	// in the future.
	spannerpb.SpannerServer

	mu sync.Mutex
	// Set to true when this server has been stopped. This is the end state of
	// a server, a stopped server cannot be restarted.
	stopped bool
	// If set, the next call that goes through simulateExecutionTime returns
	// this error, and the error is cleared.
	err error
	// The mock server creates session IDs using this counter.
	sessionCounter uint64
	// The sessions that have been created on this mock server, keyed by name.
	sessions map[string]*spannerpb.Session
	// Last use times per session name.
	sessionLastUseTime map[string]time.Time
	// The mock server creates transaction IDs per session using these
	// counters.
	transactionCounters map[string]*uint64
	// The transactions that have been created on this mock server, keyed by
	// transaction ID.
	transactions map[string]*spannerpb.Transaction
	// The transactions that have been (manually) aborted on the server.
	abortedTransactions map[string]bool
	// The transactions that are marked as PartitionedDMLTransaction.
	partitionedDmlTransactions map[string]bool
	// The mocked results for this server, keyed by SQL text and by partition
	// token respectively.
	statementResults map[string]*StatementResult
	partitionResults map[string]*StatementResult
	// The simulated execution times per method.
	executionTimes map[string]*SimulatedExecutionTime
	// The simulated errors for partial result sets, queued per SQL string.
	partialResultSetErrors map[string][]*PartialResultSetExecutionTime

	totalSessionsCreated uint
	totalSessionsDeleted uint
	// The maximum number of sessions that will be created per batch request,
	// the maximum number of sessions in total (<= 0 means no limit), and the
	// channel on which every received request is recorded.
	maxSessionsReturnedByServerPerBatchRequest int32
	maxSessionsReturnedByServerInTotal         int32
	receivedRequests                           chan interface{}
	// Session ping history.
	pings []string

	// Requests stall in ready() while this channel is open; closing it
	// releases them.
	freezed chan struct{}
}
308
309// NewInMemSpannerServer creates a new in-mem test server.
310func NewInMemSpannerServer() InMemSpannerServer {
311	res := &inMemSpannerServer{}
312	res.initDefaults()
313	res.statementResults = make(map[string]*StatementResult)
314	res.partitionResults = make(map[string]*StatementResult)
315	res.executionTimes = make(map[string]*SimulatedExecutionTime)
316	res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime)
317	res.receivedRequests = make(chan interface{}, 1000000)
318	// Produce a closed channel, so the default action of ready is to not block.
319	res.Freeze()
320	res.Unfreeze()
321	return res
322}
323
324func (s *inMemSpannerServer) Stop() {
325	s.mu.Lock()
326	defer s.mu.Unlock()
327	s.stopped = true
328	close(s.receivedRequests)
329}
330
331// Resets the test server to its initial state, deleting all sessions and
332// transactions that have been created on the server. This method will not
333// remove mocked results.
334func (s *inMemSpannerServer) Reset() {
335	s.mu.Lock()
336	defer s.mu.Unlock()
337	close(s.receivedRequests)
338	s.receivedRequests = make(chan interface{}, 1000000)
339	s.initDefaults()
340}
341
342func (s *inMemSpannerServer) SetError(err error) {
343	s.mu.Lock()
344	defer s.mu.Unlock()
345	s.err = err
346}
347
348// Registers a mocked result for a SQL statement on the server.
349func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error {
350	s.mu.Lock()
351	defer s.mu.Unlock()
352	s.statementResults[sql] = result
353	return nil
354}
355
356func (s *inMemSpannerServer) RemoveStatementResult(sql string) {
357	s.mu.Lock()
358	defer s.mu.Unlock()
359	delete(s.statementResults, sql)
360}
361
362// Registers a mocked result for a partition token on the server.
363func (s *inMemSpannerServer) PutPartitionResult(partitionToken []byte, result *StatementResult) error {
364	tokenString := string(partitionToken)
365	s.mu.Lock()
366	defer s.mu.Unlock()
367	s.partitionResults[tokenString] = result
368	return nil
369}
370
371func (s *inMemSpannerServer) AbortTransaction(id []byte) {
372	s.mu.Lock()
373	defer s.mu.Unlock()
374	s.abortedTransactions[string(id)] = true
375}
376
377func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) {
378	s.mu.Lock()
379	defer s.mu.Unlock()
380	s.executionTimes[method] = &executionTime
381}
382
383func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) {
384	s.mu.Lock()
385	defer s.mu.Unlock()
386	s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError)
387}
388
389// Freeze stalls all requests.
390func (s *inMemSpannerServer) Freeze() {
391	s.mu.Lock()
392	defer s.mu.Unlock()
393	s.freezed = make(chan struct{})
394}
395
396// Unfreeze restores processing requests.
397func (s *inMemSpannerServer) Unfreeze() {
398	s.mu.Lock()
399	defer s.mu.Unlock()
400	close(s.freezed)
401}
402
403// ready checks conditions before executing requests
404func (s *inMemSpannerServer) ready() {
405	s.mu.Lock()
406	freezed := s.freezed
407	s.mu.Unlock()
408	// check if server should be freezed
409	<-freezed
410}
411
412func (s *inMemSpannerServer) TotalSessionsCreated() uint {
413	s.mu.Lock()
414	defer s.mu.Unlock()
415	return s.totalSessionsCreated
416}
417
418func (s *inMemSpannerServer) TotalSessionsDeleted() uint {
419	s.mu.Lock()
420	defer s.mu.Unlock()
421	return s.totalSessionsDeleted
422}
423
424func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) {
425	s.mu.Lock()
426	defer s.mu.Unlock()
427	s.maxSessionsReturnedByServerPerBatchRequest = sessionCount
428}
429
430func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) {
431	s.mu.Lock()
432	defer s.mu.Unlock()
433	s.maxSessionsReturnedByServerInTotal = sessionCount
434}
435
436func (s *inMemSpannerServer) ReceivedRequests() chan interface{} {
437	return s.receivedRequests
438}
439
440// ClearPings clears the ping history from the server.
441func (s *inMemSpannerServer) ClearPings() {
442	s.mu.Lock()
443	defer s.mu.Unlock()
444	s.pings = nil
445}
446
447// DumpPings dumps the ping history.
448func (s *inMemSpannerServer) DumpPings() []string {
449	s.mu.Lock()
450	defer s.mu.Unlock()
451	return append([]string(nil), s.pings...)
452}
453
454// DumpSessions dumps the internal session table.
455func (s *inMemSpannerServer) DumpSessions() map[string]bool {
456	s.mu.Lock()
457	defer s.mu.Unlock()
458	st := map[string]bool{}
459	for s := range s.sessions {
460		st[s] = true
461	}
462	return st
463}
464
465func (s *inMemSpannerServer) initDefaults() {
466	s.sessionCounter = 0
467	s.maxSessionsReturnedByServerPerBatchRequest = 100
468	s.sessions = make(map[string]*spannerpb.Session)
469	s.sessionLastUseTime = make(map[string]time.Time)
470	s.transactions = make(map[string]*spannerpb.Transaction)
471	s.abortedTransactions = make(map[string]bool)
472	s.partitionedDmlTransactions = make(map[string]bool)
473	s.transactionCounters = make(map[string]*uint64)
474}
475
476func (s *inMemSpannerServer) generateSessionNameLocked(database string) string {
477	s.sessionCounter++
478	return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter)
479}
480
481func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) {
482	s.mu.Lock()
483	defer s.mu.Unlock()
484	session := s.sessions[name]
485	if session == nil {
486		return nil, newSessionNotFoundError(name)
487	}
488	return session, nil
489}
490
491// sessionResourceType is the type name of Spanner sessions.
492const sessionResourceType = "type.googleapis.com/google.spanner.v1.Session"
493
494func newSessionNotFoundError(name string) error {
495	s := gstatus.Newf(codes.NotFound, "Session not found: Session with id %s not found", name)
496	s, _ = s.WithDetails(&errdetails.ResourceInfo{ResourceType: sessionResourceType, ResourceName: name})
497	return s.Err()
498}
499
500func (s *inMemSpannerServer) updateSessionLastUseTime(session string) {
501	s.mu.Lock()
502	defer s.mu.Unlock()
503	s.sessionLastUseTime[session] = time.Now()
504}
505
506func getCurrentTimestamp() *timestamp.Timestamp {
507	t := time.Now()
508	return &timestamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())}
509}
510
511// Gets the transaction id from the transaction selector. If the selector
512// specifies that a new transaction should be started, this method will start
513// a new transaction and return the id of that transaction.
514func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte {
515	var res []byte
516	if txSelector.GetBegin() != nil {
517		// Start a new transaction.
518		res = s.beginTransaction(session, txSelector.GetBegin()).Id
519	} else if txSelector.GetId() != nil {
520		res = txSelector.GetId()
521	}
522	return res
523}
524
525func (s *inMemSpannerServer) generateTransactionName(session string) string {
526	s.mu.Lock()
527	defer s.mu.Unlock()
528	counter, ok := s.transactionCounters[session]
529	if !ok {
530		counter = new(uint64)
531		s.transactionCounters[session] = counter
532	}
533	*counter++
534	return fmt.Sprintf("%s/transactions/%d", session, *counter)
535}
536
537func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction {
538	id := s.generateTransactionName(session.Name)
539	res := &spannerpb.Transaction{
540		Id:            []byte(id),
541		ReadTimestamp: getCurrentTimestamp(),
542	}
543	s.mu.Lock()
544	s.transactions[id] = res
545	s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil
546	s.mu.Unlock()
547	return res
548}
549
550func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) {
551	s.mu.Lock()
552	defer s.mu.Unlock()
553	tx, ok := s.transactions[string(id)]
554	if !ok {
555		return nil, gstatus.Error(codes.NotFound, "Transaction not found")
556	}
557	aborted, ok := s.abortedTransactions[string(id)]
558	if ok && aborted {
559		return nil, newAbortedErrorWithMinimalRetryDelay()
560	}
561	return tx, nil
562}
563
564func newAbortedErrorWithMinimalRetryDelay() error {
565	st := gstatus.New(codes.Aborted, "Transaction has been aborted")
566	retry := &errdetails.RetryInfo{
567		RetryDelay: ptypes.DurationProto(time.Nanosecond),
568	}
569	st, _ = st.WithDetails(retry)
570	return st.Err()
571}
572
573func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) {
574	s.mu.Lock()
575	defer s.mu.Unlock()
576	delete(s.transactions, string(tx.Id))
577	delete(s.partitionedDmlTransactions, string(tx.Id))
578}
579
580func (s *inMemSpannerServer) getPartitionResult(partitionToken []byte) (*StatementResult, error) {
581	tokenString := string(partitionToken)
582	s.mu.Lock()
583	defer s.mu.Unlock()
584	result, ok := s.partitionResults[tokenString]
585	if !ok {
586		return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for partition token %v", tokenString))
587	}
588	return result, nil
589}
590
591func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) {
592	s.mu.Lock()
593	defer s.mu.Unlock()
594	result, ok := s.statementResults[sql]
595	if !ok {
596		return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql))
597	}
598	return result, nil
599}
600
601func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error {
602	s.mu.Lock()
603	if s.stopped {
604		s.mu.Unlock()
605		return gstatus.Error(codes.Unavailable, "server has been stopped")
606	}
607	s.receivedRequests <- req
608	s.mu.Unlock()
609	s.ready()
610	s.mu.Lock()
611	if s.err != nil {
612		err := s.err
613		s.err = nil
614		s.mu.Unlock()
615		return err
616	}
617	executionTime, ok := s.executionTimes[method]
618	s.mu.Unlock()
619	if ok {
620		var randTime int64
621		if executionTime.RandomExecutionTime > 0 {
622			randTime = rand.Int63n(int64(executionTime.RandomExecutionTime))
623		}
624		totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime)
625		<-time.After(totalExecutionTime)
626		s.mu.Lock()
627		if executionTime.Errors != nil && len(executionTime.Errors) > 0 {
628			err := executionTime.Errors[0]
629			if !executionTime.KeepError {
630				executionTime.Errors = executionTime.Errors[1:]
631			}
632			s.mu.Unlock()
633			return err
634		}
635		s.mu.Unlock()
636	}
637	return nil
638}
639
640func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
641	if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
642		return nil, err
643	}
644	if req.Database == "" {
645		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
646	}
647	s.mu.Lock()
648	defer s.mu.Unlock()
649	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal {
650		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
651	}
652	sessionName := s.generateSessionNameLocked(req.Database)
653	ts := getCurrentTimestamp()
654	session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
655	s.totalSessionsCreated++
656	s.sessions[sessionName] = session
657	return session, nil
658}
659
660func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
661	if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
662		return nil, err
663	}
664	if req.Database == "" {
665		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
666	}
667	if req.SessionCount <= 0 {
668		return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0")
669	}
670	sessionsToCreate := req.SessionCount
671	s.mu.Lock()
672	defer s.mu.Unlock()
673	if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal {
674		return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available")
675	}
676	if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest {
677		sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest
678	}
679	if s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal {
680		sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions))
681	}
682	sessions := make([]*spannerpb.Session, sessionsToCreate)
683	for i := int32(0); i < sessionsToCreate; i++ {
684		sessionName := s.generateSessionNameLocked(req.Database)
685		ts := getCurrentTimestamp()
686		sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts}
687		s.totalSessionsCreated++
688		s.sessions[sessionName] = sessions[i]
689	}
690	return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil
691}
692
693func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
694	if err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
695		return nil, err
696	}
697	if req.Name == "" {
698		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
699	}
700	session, err := s.findSession(req.Name)
701	if err != nil {
702		return nil, err
703	}
704	return session, nil
705}
706
707func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) {
708	s.mu.Lock()
709	if s.stopped {
710		s.mu.Unlock()
711		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
712	}
713	s.receivedRequests <- req
714	s.mu.Unlock()
715	if req.Database == "" {
716		return nil, gstatus.Error(codes.InvalidArgument, "Missing database")
717	}
718	expectedSessionName := req.Database + "/sessions/"
719	var sessions []*spannerpb.Session
720	s.mu.Lock()
721	for _, session := range s.sessions {
722		if strings.Index(session.Name, expectedSessionName) == 0 {
723			sessions = append(sessions, session)
724		}
725	}
726	s.mu.Unlock()
727	sort.Slice(sessions[:], func(i, j int) bool {
728		return sessions[i].Name < sessions[j].Name
729	})
730	res := &spannerpb.ListSessionsResponse{Sessions: sessions}
731	return res, nil
732}
733
734func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
735	if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
736		return nil, err
737	}
738	if req.Name == "" {
739		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
740	}
741	if _, err := s.findSession(req.Name); err != nil {
742		return nil, err
743	}
744	s.mu.Lock()
745	defer s.mu.Unlock()
746	s.totalSessionsDeleted++
747	delete(s.sessions, req.Name)
748	return &emptypb.Empty{}, nil
749}
750
751func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
752	if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
753		return nil, err
754	}
755	if req.Sql == "SELECT 1" {
756		s.mu.Lock()
757		s.pings = append(s.pings, req.Session)
758		s.mu.Unlock()
759	}
760	if req.Session == "" {
761		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
762	}
763	session, err := s.findSession(req.Session)
764	if err != nil {
765		return nil, err
766	}
767	var id []byte
768	s.updateSessionLastUseTime(session.Name)
769	if id = s.getTransactionID(session, req.Transaction); id != nil {
770		_, err = s.getTransactionByID(id)
771		if err != nil {
772			return nil, err
773		}
774	}
775	var statementResult *StatementResult
776	if req.PartitionToken != nil {
777		statementResult, err = s.getPartitionResult(req.PartitionToken)
778	} else {
779		statementResult, err = s.getStatementResult(req.Sql)
780	}
781	if err != nil {
782		return nil, err
783	}
784	s.mu.Lock()
785	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
786	s.mu.Unlock()
787	switch statementResult.Type {
788	case StatementResultError:
789		return nil, statementResult.Err
790	case StatementResultResultSet:
791		return statementResult.ResultSet, nil
792	case StatementResultUpdateCount:
793		return statementResult.convertUpdateCountToResultSet(!isPartitionedDml), nil
794	}
795	return nil, gstatus.Error(codes.Internal, "Unknown result type")
796}
797
798func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
799	if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
800		return err
801	}
802	return s.executeStreamingSQL(req, stream)
803}
804
// executeStreamingSQL serves a streaming query (it is also used by
// StreamingRead). It validates the session, resolves or begins the selected
// transaction, and looks up the mocked result by partition token or SQL text.
// It then responds as follows:
//   - result sets are sent as PartialResultSets, starting at req.ResumeToken;
//   - update counts are sent as one PartialResultSet that holds only
//     statistics (a lower bound in Partitioned DML transactions);
//   - mocked errors are returned as-is.
//
// For result sets, at most one queued PartialResultSetExecutionTime for
// req.Sql is consumed per call. It only takes effect if one of the parts
// carries its resume token.
func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
	if req.Session == "" {
		return gstatus.Error(codes.InvalidArgument, "Missing session name")
	}
	session, err := s.findSession(req.Session)
	if err != nil {
		return err
	}
	s.updateSessionLastUseTime(session.Name)
	var id []byte
	if id = s.getTransactionID(session, req.Transaction); id != nil {
		_, err = s.getTransactionByID(id)
		if err != nil {
			return err
		}
	}
	var statementResult *StatementResult
	if req.PartitionToken != nil {
		statementResult, err = s.getPartitionResult(req.PartitionToken)
	} else {
		statementResult, err = s.getStatementResult(req.Sql)
	}
	if err != nil {
		return err
	}
	s.mu.Lock()
	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
	s.mu.Unlock()
	switch statementResult.Type {
	case StatementResultError:
		return statementResult.Err
	case StatementResultResultSet:
		parts, err := statementResult.ToPartialResultSets(req.ResumeToken)
		if err != nil {
			return err
		}
		// Dequeue the next simulated partial-result delay/error for this SQL
		// string, if any. The queue is keyed by req.Sql even when the result
		// was looked up by partition token.
		var nextPartialResultSetError *PartialResultSetExecutionTime
		s.mu.Lock()
		pErrors := s.partialResultSetErrors[req.Sql]
		if len(pErrors) > 0 {
			nextPartialResultSetError = pErrors[0]
			s.partialResultSetErrors[req.Sql] = pErrors[1:]
		}
		s.mu.Unlock()
		for _, part := range parts {
			// Delay and/or fail at the part whose resume token matches.
			if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) {
				if nextPartialResultSetError.ExecutionTime > 0 {
					<-time.After(nextPartialResultSetError.ExecutionTime)
				}
				if nextPartialResultSetError.Err != nil {
					return nextPartialResultSetError.Err
				}
			}
			if err := stream.Send(part); err != nil {
				return err
			}
		}
		return nil
	case StatementResultUpdateCount:
		part := statementResult.updateCountToPartialResultSet(!isPartitionedDml)
		if err := stream.Send(part); err != nil {
			return err
		}
		return nil
	}
	return gstatus.Error(codes.Internal, "Unknown result type")
}
872
873func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
874	if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
875		return nil, err
876	}
877	if req.Session == "" {
878		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
879	}
880	session, err := s.findSession(req.Session)
881	if err != nil {
882		return nil, err
883	}
884	s.updateSessionLastUseTime(session.Name)
885	var id []byte
886	if id = s.getTransactionID(session, req.Transaction); id != nil {
887		_, err = s.getTransactionByID(id)
888		if err != nil {
889			return nil, err
890		}
891	}
892	s.mu.Lock()
893	isPartitionedDml := s.partitionedDmlTransactions[string(id)]
894	s.mu.Unlock()
895	resp := &spannerpb.ExecuteBatchDmlResponse{}
896	resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements))
897	resp.Status = &status.Status{Code: int32(codes.OK)}
898	for idx, batchStatement := range req.Statements {
899		statementResult, err := s.getStatementResult(batchStatement.Sql)
900		if err != nil {
901			return nil, err
902		}
903		switch statementResult.Type {
904		case StatementResultError:
905			resp.Status = &status.Status{Code: int32(gstatus.Code(statementResult.Err)), Message: statementResult.Err.Error()}
906			resp.ResultSets = resp.ResultSets[:idx]
907			return resp, nil
908		case StatementResultResultSet:
909			return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql))
910		case StatementResultUpdateCount:
911			resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!isPartitionedDml)
912		}
913	}
914	return resp, nil
915}
916
917func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) {
918	s.mu.Lock()
919	if s.stopped {
920		s.mu.Unlock()
921		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
922	}
923	s.receivedRequests <- req
924	s.mu.Unlock()
925	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
926}
927
928func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
929	if err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil {
930		return err
931	}
932	sqlReq := &spannerpb.ExecuteSqlRequest{
933		Session:        req.Session,
934		Transaction:    req.Transaction,
935		PartitionToken: req.PartitionToken,
936		ResumeToken:    req.ResumeToken,
937		// KeySet is currently ignored.
938		Sql: fmt.Sprintf(
939			"SELECT %s FROM %s",
940			strings.Join(req.Columns, ", "),
941			req.Table,
942		),
943	}
944	return s.executeStreamingSQL(sqlReq, stream)
945}
946
947func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
948	if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
949		return nil, err
950	}
951	if req.Session == "" {
952		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
953	}
954	session, err := s.findSession(req.Session)
955	if err != nil {
956		return nil, err
957	}
958	s.updateSessionLastUseTime(session.Name)
959	tx := s.beginTransaction(session, req.Options)
960	return tx, nil
961}
962
963func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) {
964	if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil {
965		return nil, err
966	}
967	if req.Session == "" {
968		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
969	}
970	session, err := s.findSession(req.Session)
971	if err != nil {
972		return nil, err
973	}
974	s.updateSessionLastUseTime(session.Name)
975	var tx *spannerpb.Transaction
976	if req.GetSingleUseTransaction() != nil {
977		tx = s.beginTransaction(session, req.GetSingleUseTransaction())
978	} else if req.GetTransactionId() != nil {
979		tx, err = s.getTransactionByID(req.GetTransactionId())
980		if err != nil {
981			return nil, err
982		}
983	} else {
984		return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
985	}
986	s.removeTransaction(tx)
987	resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
988	if req.ReturnCommitStats {
989		resp.CommitStats = &spannerpb.CommitResponse_CommitStats{
990			MutationCount: int64(1),
991		}
992	}
993	return resp, nil
994}
995
996func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
997	s.mu.Lock()
998	if s.stopped {
999		s.mu.Unlock()
1000		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
1001	}
1002	s.receivedRequests <- req
1003	s.mu.Unlock()
1004	if req.Session == "" {
1005		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
1006	}
1007	session, err := s.findSession(req.Session)
1008	if err != nil {
1009		return nil, err
1010	}
1011	s.updateSessionLastUseTime(session.Name)
1012	tx, err := s.getTransactionByID(req.TransactionId)
1013	if err != nil {
1014		return nil, err
1015	}
1016	s.removeTransaction(tx)
1017	return &emptypb.Empty{}, nil
1018}
1019
1020func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
1021	s.mu.Lock()
1022	if s.stopped {
1023		s.mu.Unlock()
1024		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
1025	}
1026	s.receivedRequests <- req
1027	s.mu.Unlock()
1028	if req.Session == "" {
1029		return nil, gstatus.Error(codes.InvalidArgument, "Missing session name")
1030	}
1031	session, err := s.findSession(req.Session)
1032	if err != nil {
1033		return nil, err
1034	}
1035	var id []byte
1036	var tx *spannerpb.Transaction
1037	s.updateSessionLastUseTime(session.Name)
1038	if id = s.getTransactionID(session, req.Transaction); id != nil {
1039		tx, err = s.getTransactionByID(id)
1040		if err != nil {
1041			return nil, err
1042		}
1043	}
1044	var partitions []*spannerpb.Partition
1045	for i := int64(0); i < req.PartitionOptions.MaxPartitions; i++ {
1046		token := make([]byte, 10)
1047		_, err := rand.Read(token)
1048		if err != nil {
1049			return nil, gstatus.Error(codes.Internal, "failed to generate random partition token")
1050		}
1051		partitions = append(partitions, &spannerpb.Partition{PartitionToken: token})
1052	}
1053	return &spannerpb.PartitionResponse{
1054		Partitions:  partitions,
1055		Transaction: tx,
1056	}, nil
1057}
1058
1059func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) {
1060	s.mu.Lock()
1061	if s.stopped {
1062		s.mu.Unlock()
1063		return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
1064	}
1065	s.receivedRequests <- req
1066	s.mu.Unlock()
1067	return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented")
1068}
1069
// EncodeResumeToken returns the mock resume token for t: the unsigned varint
// encoding of t, zero-padded to a fixed length of 16 bytes.
func EncodeResumeToken(t uint64) []byte {
	var token [16]byte
	binary.PutUvarint(token[:], t)
	return token[:]
}
1076
// DecodeResumeToken is the inverse of EncodeResumeToken: it reads the
// unsigned varint at the start of t. Any bytes after the varint, such as the
// zero padding added by EncodeResumeToken, are ignored. An error is returned
// when t does not begin with a complete varint, or when the encoded value
// overflows a uint64.
func DecodeResumeToken(t []byte) (uint64, error) {
	v, read := binary.Uvarint(t)
	if read > 0 {
		return v, nil
	}
	return 0, fmt.Errorf("invalid resume token: %v", t)
}
1085