1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17/*
18Package spannertest contains test helpers for working with Cloud Spanner.
19
20This package is EXPERIMENTAL, and is lacking several features. See the README.md
21file in this directory for more details.
22
23In-memory fake
24
25This package has an in-memory fake implementation of spanner. To use it,
26create a Server, and then connect to it with no security:
27	srv, err := spannertest.NewServer("localhost:0")
28	...
29	conn, err := grpc.DialContext(ctx, srv.Addr, grpc.WithInsecure())
30	...
31	client, err := spanner.NewClient(ctx, db, option.WithGRPCConn(conn))
32	...
33
34Alternatively, create a Server, then set the SPANNER_EMULATOR_HOST environment
35variable and use the regular spanner.NewClient:
36	srv, err := spannertest.NewServer("localhost:0")
37	...
38	os.Setenv("SPANNER_EMULATOR_HOST", srv.Addr)
39	client, err := spanner.NewClient(ctx, db)
40	...
41
42The same server also supports database admin operations for use with
43the cloud.google.com/go/spanner/admin/database/apiv1 package. This only
44simulates the existence of a single database; its name is ignored.
45*/
46package spannertest
47
48import (
49	"context"
50	"encoding/base64"
51	"fmt"
52	"io"
53	"log"
54	"math/rand"
55	"net"
56	"strconv"
57	"sync"
58	"sync/atomic"
59	"time"
60
61	"github.com/golang/protobuf/proto"
62	"github.com/golang/protobuf/ptypes"
63	"google.golang.org/grpc"
64	"google.golang.org/grpc/codes"
65	"google.golang.org/grpc/status"
66
67	anypb "github.com/golang/protobuf/ptypes/any"
68	emptypb "github.com/golang/protobuf/ptypes/empty"
69	structpb "github.com/golang/protobuf/ptypes/struct"
70	timestamppb "github.com/golang/protobuf/ptypes/timestamp"
71	lropb "google.golang.org/genproto/googleapis/longrunning"
72	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
73	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
74
75	"cloud.google.com/go/civil"
76	"cloud.google.com/go/spanner/spansql"
77)
78
// Server is an in-memory Cloud Spanner fake.
// It is unauthenticated, non-performant, and only a rough approximation.
//
// Create one with NewServer and shut it down with Close.
type Server struct {
	// Addr is the resolved address of the TCP listener, suitable for dialing
	// or for SPANNER_EMULATOR_HOST.
	Addr string

	// l is the listener that srv serves on.
	l   net.Listener
	// srv is the gRPC server with the admin, spanner and LRO services registered.
	srv *grpc.Server
	// s holds the fake's state and RPC method implementations.
	s   *server
}
88
// server is the real implementation of the fake.
// It is a separate and unexported type so the API won't be cluttered with
// methods that are only relevant to the fake's implementation.
type server struct {
	// logf receives the fake's log output; it can be replaced via Server.SetLogger.
	logf Logger

	// db is the single simulated database; requested database names are ignored.
	db    database
	// start is reported as the database's CreateTime by GetDatabase.
	start time.Time

	// mu guards sessions and lros (but not the individual values they hold,
	// which have their own locks).
	mu       sync.Mutex
	// sessions maps session IDs to live sessions; entries are removed by DeleteSession.
	sessions map[string]*session
	// lros maps operation names to long-running operations (schema changes).
	lros     map[string]*lro

	// These embedded interfaces are left nil by NewServer, so calling any
	// method not implemented on *server dereferences a nil interface.
	// Any unimplemented methods will cause a panic.
	// TODO: Switch to Unimplemented at some point? spannerpb would need regenerating.
	adminpb.DatabaseAdminServer
	spannerpb.SpannerServer
	lropb.OperationsServer
}
108
// session is the server-side state of one Spanner session.
type session struct {
	// name is the session ID; it is also the session's key in server.sessions.
	name     string
	// creation is the time the session was created.
	creation time.Time

	// This context tracks the lifetime of this session.
	// It is canceled in DeleteSession.
	ctx    context.Context
	cancel func()

	// mu guards lastUse and transactions.
	mu           sync.Mutex
	// lastUse is refreshed whenever a transaction-related RPC touches the session.
	lastUse      time.Time
	// transactions maps transaction IDs handed out by BeginTransaction to the
	// open transactions; Commit and Rollback remove them.
	transactions map[string]*transaction
}
122
123func (s *session) Proto() *spannerpb.Session {
124	s.mu.Lock()
125	defer s.mu.Unlock()
126	m := &spannerpb.Session{
127		Name:                   s.name,
128		CreateTime:             timestampProto(s.creation),
129		ApproximateLastUseTime: timestampProto(s.lastUse),
130	}
131	return m
132}
133
134// timestampProto returns a valid timestamp.Timestamp,
135// or nil if the given time is zero or isn't representable.
136func timestampProto(t time.Time) *timestamppb.Timestamp {
137	if t.IsZero() {
138		return nil
139	}
140	ts, err := ptypes.TimestampProto(t)
141	if err != nil {
142		return nil
143	}
144	return ts
145}
146
// lro represents a Long-Running Operation, generally a schema change.
type lro struct {
	// mu guards state.
	mu    sync.Mutex
	// state is the operation as reported to clients; State hands out deep copies.
	state *lropb.Operation

	// waitc is closed when anyone starts waiting on the LRO.
	// waitatom is CAS'd from 0 to 1 to make that closing safe.
	// Once waitc is closed, Run stops simulating per-statement delays.
	waitc    chan struct{}
	waitatom int32
}
157
158func newLRO(initState *lropb.Operation) *lro {
159	return &lro{
160		state: initState,
161		waitc: make(chan struct{}),
162	}
163}
164
165func (l *lro) noWait() {
166	if atomic.CompareAndSwapInt32(&l.waitatom, 0, 1) {
167		close(l.waitc)
168	}
169}
170
171func (l *lro) State() *lropb.Operation {
172	l.mu.Lock()
173	defer l.mu.Unlock()
174	return proto.Clone(l.state).(*lropb.Operation)
175}
176
// Logger is something that can be used for logging.
// It receives printf-style format strings and arguments.
// It is matched by log.Printf and testing.T.Logf, so either may be passed
// to Server.SetLogger.
type Logger func(format string, args ...interface{})
180
181// NewServer creates a new Server.
182// The Server will be listening for gRPC connections, without TLS, on the provided TCP address.
183// The resolved address is available in the Addr field.
184func NewServer(laddr string) (*Server, error) {
185	l, err := net.Listen("tcp", laddr)
186	if err != nil {
187		return nil, err
188	}
189
190	s := &Server{
191		Addr: l.Addr().String(),
192		l:    l,
193		srv:  grpc.NewServer(),
194		s: &server{
195			logf: func(format string, args ...interface{}) {
196				log.Printf("spannertest.inmem: "+format, args...)
197			},
198			start:    time.Now(),
199			sessions: make(map[string]*session),
200			lros:     make(map[string]*lro),
201		},
202	}
203	adminpb.RegisterDatabaseAdminServer(s.srv, s.s)
204	spannerpb.RegisterSpannerServer(s.srv, s.s)
205	lropb.RegisterOperationsServer(s.srv, s.s)
206
207	go s.srv.Serve(s.l)
208
209	return s, nil
210}
211
212// SetLogger sets a logger for the server.
213// You can use a *testing.T as this argument to collate extra information
214// from the execution of the server.
215func (s *Server) SetLogger(l Logger) { s.s.logf = l }
216
217// Close shuts down the server.
218func (s *Server) Close() {
219	s.srv.Stop()
220	s.l.Close()
221}
222
// genRandomSession returns a new session ID: 4 bytes from math/rand,
// rendered as 8 lowercase hex digits. IDs are not guaranteed unique.
func genRandomSession() string {
	buf := make([]byte, 4)
	rand.Read(buf)
	return fmt.Sprintf("%x", buf)
}
228
// genRandomTransaction returns a new transaction ID of the form "tx-" followed
// by 12 lowercase hex digits (6 bytes from math/rand).
func genRandomTransaction() string {
	buf := make([]byte, 6)
	rand.Read(buf)
	return "tx-" + fmt.Sprintf("%x", buf)
}
234
// genRandomOperation returns a new operation ID suffix of the form "op-"
// followed by 6 lowercase hex digits (3 bytes from math/rand).
func genRandomOperation() string {
	buf := make([]byte, 3)
	rand.Read(buf)
	return "op-" + fmt.Sprintf("%x", buf)
}
240
241func (s *server) GetOperation(ctx context.Context, req *lropb.GetOperationRequest) (*lropb.Operation, error) {
242	s.mu.Lock()
243	lro, ok := s.lros[req.Name]
244	s.mu.Unlock()
245	if !ok {
246		return nil, status.Errorf(codes.NotFound, "unknown LRO %q", req.Name)
247	}
248
249	// Someone is waiting on this LRO. Disable sleeping in its Run method.
250	lro.noWait()
251
252	return lro.State(), nil
253}
254
255func (s *server) GetDatabase(ctx context.Context, req *adminpb.GetDatabaseRequest) (*adminpb.Database, error) {
256	s.logf("GetDatabase(%q)", req.Name)
257
258	return &adminpb.Database{
259		Name:       req.Name,
260		State:      adminpb.Database_READY,
261		CreateTime: timestampProto(s.start),
262	}, nil
263}
264
265// UpdateDDL applies the given DDL to the server.
266//
267// This is a convenience method for tests that may assume an existing schema.
268// The more general approach is to dial this server using an admin client, and
269// use the UpdateDatabaseDdl RPC method.
270func (s *Server) UpdateDDL(ddl *spansql.DDL) error {
271	ctx := context.Background()
272	for _, stmt := range ddl.List {
273		if st := s.s.runOneDDL(ctx, stmt); st.Code() != codes.OK {
274			return st.Err()
275		}
276	}
277	return nil
278}
279
280func (s *server) UpdateDatabaseDdl(ctx context.Context, req *adminpb.UpdateDatabaseDdlRequest) (*lropb.Operation, error) {
281	// Parse all the DDL statements first.
282	var stmts []spansql.DDLStmt
283	for _, s := range req.Statements {
284		stmt, err := spansql.ParseDDLStmt(s)
285		if err != nil {
286			// TODO: check what code the real Spanner returns here.
287			return nil, status.Errorf(codes.InvalidArgument, "bad DDL statement %q: %v", s, err)
288		}
289		stmts = append(stmts, stmt)
290	}
291
292	// Nothing should be depending on the exact structure of this,
293	// but it is specified in google/spanner/admin/database/v1/spanner_database_admin.proto.
294	id := "projects/fake-proj/instances/fake-instance/databases/fake-db/operations/" + genRandomOperation()
295	lro := newLRO(&lropb.Operation{Name: id})
296	s.mu.Lock()
297	s.lros[id] = lro
298	s.mu.Unlock()
299
300	go lro.Run(s, stmts)
301	return lro.State(), nil
302}
303
304func (l *lro) Run(s *server, stmts []spansql.DDLStmt) {
305	ctx := context.Background()
306
307	for _, stmt := range stmts {
308		// Simulate delayed DDL application, but only if nobody is waiting.
309		select {
310		case <-time.After(100 * time.Millisecond):
311		case <-l.waitc:
312		}
313
314		if st := s.runOneDDL(ctx, stmt); st.Code() != codes.OK {
315			l.mu.Lock()
316			l.state.Done = true
317			l.state.Result = &lropb.Operation_Error{st.Proto()}
318			l.mu.Unlock()
319			return
320		}
321	}
322
323	l.mu.Lock()
324	l.state.Done = true
325	l.state.Result = &lropb.Operation_Response{&anypb.Any{}}
326	l.mu.Unlock()
327}
328
329func (s *server) runOneDDL(ctx context.Context, stmt spansql.DDLStmt) *status.Status {
330	return s.db.ApplyDDL(stmt)
331}
332
333func (s *server) GetDatabaseDdl(ctx context.Context, req *adminpb.GetDatabaseDdlRequest) (*adminpb.GetDatabaseDdlResponse, error) {
334	s.logf("GetDatabaseDdl(%q)", req.Database)
335
336	var resp adminpb.GetDatabaseDdlResponse
337	for _, stmt := range s.db.GetDDL() {
338		resp.Statements = append(resp.Statements, stmt.SQL())
339	}
340	return &resp, nil
341}
342
343func (s *server) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
344	//s.logf("CreateSession(%q)", req.Database)
345	return s.newSession(), nil
346}
347
348func (s *server) newSession() *spannerpb.Session {
349	id := genRandomSession()
350	now := time.Now()
351	sess := &session{
352		name:         id,
353		creation:     now,
354		lastUse:      now,
355		transactions: make(map[string]*transaction),
356	}
357	sess.ctx, sess.cancel = context.WithCancel(context.Background())
358
359	s.mu.Lock()
360	s.sessions[id] = sess
361	s.mu.Unlock()
362
363	return sess.Proto()
364}
365
366func (s *server) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
367	//s.logf("BatchCreateSessions(%q)", req.Database)
368
369	var sessions []*spannerpb.Session
370	for i := int32(0); i < req.GetSessionCount(); i++ {
371		sessions = append(sessions, s.newSession())
372	}
373
374	return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil
375}
376
377func (s *server) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
378	s.mu.Lock()
379	sess, ok := s.sessions[req.Name]
380	s.mu.Unlock()
381
382	if !ok {
383		// TODO: what error does the real Spanner return?
384		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name)
385	}
386
387	return sess.Proto(), nil
388}
389
390// TODO: ListSessions
391
392func (s *server) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
393	//s.logf("DeleteSession(%q)", req.Name)
394
395	s.mu.Lock()
396	sess, ok := s.sessions[req.Name]
397	delete(s.sessions, req.Name)
398	s.mu.Unlock()
399
400	if !ok {
401		// TODO: what error does the real Spanner return?
402		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name)
403	}
404
405	// Terminate any operations in this session.
406	sess.cancel()
407
408	return &emptypb.Empty{}, nil
409}
410
411// popTx returns an existing transaction, removing it from the session.
412// This is called when a transaction is finishing (Commit, Rollback).
413func (s *server) popTx(sessionID, tid string) (tx *transaction, err error) {
414	s.mu.Lock()
415	sess, ok := s.sessions[sessionID]
416	s.mu.Unlock()
417	if !ok {
418		// TODO: what error does the real Spanner return?
419		return nil, status.Errorf(codes.NotFound, "unknown session %q", sessionID)
420	}
421
422	sess.mu.Lock()
423	sess.lastUse = time.Now()
424	tx, ok = sess.transactions[tid]
425	if ok {
426		delete(sess.transactions, tid)
427	}
428	sess.mu.Unlock()
429	if !ok {
430		// TODO: what error does the real Spanner return?
431		return nil, status.Errorf(codes.NotFound, "unknown transaction ID %q", tid)
432	}
433	return tx, nil
434}
435
436// readTx returns a transaction for the given session and transaction selector.
437// It is used by read/query operations (ExecuteStreamingSql, StreamingRead).
438func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.TransactionSelector) (tx *transaction, cleanup func(), err error) {
439	s.mu.Lock()
440	sess, ok := s.sessions[session]
441	s.mu.Unlock()
442	if !ok {
443		// TODO: what error does the real Spanner return?
444		return nil, nil, status.Errorf(codes.NotFound, "unknown session %q", session)
445	}
446
447	sess.mu.Lock()
448	sess.lastUse = time.Now()
449	sess.mu.Unlock()
450
451	// Only give a read-only transaction regardless of whether the selector
452	// is requesting a read-write or read-only one, since this is in readTx
453	// and so shouldn't be mutating anyway.
454	singleUse := func() (*transaction, func(), error) {
455		tx := s.db.NewReadOnlyTransaction()
456		return tx, tx.Rollback, nil
457	}
458
459	if tsel.GetSelector() == nil {
460		return singleUse()
461	}
462
463	switch sel := tsel.Selector.(type) {
464	default:
465		return nil, nil, fmt.Errorf("TransactionSelector type %T not supported", sel)
466	case *spannerpb.TransactionSelector_SingleUse:
467		// Ignore options (e.g. timestamps).
468		switch mode := sel.SingleUse.Mode.(type) {
469		case *spannerpb.TransactionOptions_ReadOnly_:
470			return singleUse()
471		case *spannerpb.TransactionOptions_ReadWrite_:
472			return singleUse()
473		default:
474			return nil, nil, fmt.Errorf("single use transaction in mode %T not supported", mode)
475		}
476	case *spannerpb.TransactionSelector_Id:
477		sess.mu.Lock()
478		tx, ok := sess.transactions[string(sel.Id)]
479		sess.mu.Unlock()
480		if !ok {
481			return nil, nil, fmt.Errorf("no transaction with id %q", sel.Id)
482		}
483		return tx, func() {}, nil
484	}
485}
486
487func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
488	// Assume this is probably a DML statement or a ping from the session pool.
489	// Queries normally use ExecuteStreamingSql.
490	// TODO: Expand this to support more things.
491
492	// If it is a single-use transaction we assume it is a query.
493	if req.Transaction.GetSelector() == nil || req.Transaction.GetSingleUse().GetReadOnly() != nil {
494		ri, err := s.executeQuery(req)
495		if err != nil {
496			return nil, err
497		}
498		return s.resultSet(ri)
499	}
500
501	obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id)
502	if !ok {
503		return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector)
504	}
505	tid := string(obj.Id)
506	_ = tid // TODO: lookup an existing transaction by ID.
507
508	stmt, err := spansql.ParseDMLStmt(req.Sql)
509	if err != nil {
510		return nil, status.Errorf(codes.InvalidArgument, "bad DML: %v", err)
511	}
512	params, err := parseQueryParams(req.GetParams(), req.ParamTypes)
513	if err != nil {
514		return nil, err
515	}
516
517	s.logf("Executing: %s", stmt.SQL())
518	if len(params) > 0 {
519		s.logf("        ▹ %v", params)
520	}
521
522	n, err := s.db.Execute(stmt, params)
523	if err != nil {
524		return nil, err
525	}
526	return &spannerpb.ResultSet{
527		Stats: &spannerpb.ResultSetStats{
528			RowCount: &spannerpb.ResultSetStats_RowCountExact{int64(n)},
529		},
530	}, nil
531}
532
533func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
534	tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction)
535	if err != nil {
536		return err
537	}
538	defer cleanup()
539
540	ri, err := s.executeQuery(req)
541	if err != nil {
542		return err
543	}
544	return s.readStream(stream.Context(), tx, stream.Send, ri)
545}
546
547func (s *server) executeQuery(req *spannerpb.ExecuteSqlRequest) (ri rowIter, err error) {
548	q, err := spansql.ParseQuery(req.Sql)
549	if err != nil {
550		// TODO: check what code the real Spanner returns here.
551		return nil, status.Errorf(codes.InvalidArgument, "bad query: %v", err)
552	}
553
554	params, err := parseQueryParams(req.GetParams(), req.ParamTypes)
555	if err != nil {
556		return nil, err
557	}
558
559	s.logf("Querying: %s", q.SQL())
560	if len(params) > 0 {
561		s.logf("        ▹ %v", params)
562	}
563
564	return s.db.Query(q, params)
565}
566
567// TODO: Read
568
569func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
570	tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction)
571	if err != nil {
572		return err
573	}
574	defer cleanup()
575
576	// Bail out if various advanced features are being used.
577	if req.Index != "" {
578		// This is okay; we can still return results.
579		s.logf("Warning: index reads (%q) not supported", req.Index)
580	}
581	if len(req.ResumeToken) > 0 {
582		// This should only happen if we send resume_token ourselves.
583		return fmt.Errorf("read resumption not supported")
584	}
585	if len(req.PartitionToken) > 0 {
586		return fmt.Errorf("partition restrictions not supported")
587	}
588
589	var ri rowIter
590	if req.KeySet.All {
591		s.logf("Reading all from %s (cols: %v)", req.Table, req.Columns)
592		ri, err = s.db.ReadAll(spansql.ID(req.Table), idList(req.Columns), req.Limit)
593	} else {
594		s.logf("Reading rows from %d keys and %d ranges from %s (cols: %v)", len(req.KeySet.Keys), len(req.KeySet.Ranges), req.Table, req.Columns)
595		ri, err = s.db.Read(spansql.ID(req.Table), idList(req.Columns), req.KeySet.Keys, makeKeyRangeList(req.KeySet.Ranges), req.Limit)
596	}
597	if err != nil {
598		return err
599	}
600
601	// TODO: Figure out the right contexts to use here. There's the session one (sess.ctx),
602	// but also this specific RPC one (stream.Context()). Which takes precedence?
603	// They appear to be independent.
604
605	return s.readStream(stream.Context(), tx, stream.Send, ri)
606}
607
608func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) {
609	rsm, err := s.buildResultSetMetadata(ri)
610	if err != nil {
611		return nil, err
612	}
613	rs := &spannerpb.ResultSet{
614		Metadata: rsm,
615	}
616	for {
617		row, err := ri.Next()
618		if err == io.EOF {
619			break
620		} else if err != nil {
621			return nil, err
622		}
623
624		values := make([]*structpb.Value, len(row))
625		for i, x := range row {
626			v, err := spannerValueFromValue(x)
627			if err != nil {
628				return nil, err
629			}
630			values[i] = v
631		}
632		rs.Rows = append(rs.Rows, &structpb.ListValue{Values: values})
633	}
634	return rs, nil
635}
636
637func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error {
638	rsm, err := s.buildResultSetMetadata(ri)
639	if err != nil {
640		return err
641	}
642
643	for {
644		row, err := ri.Next()
645		if err == io.EOF {
646			break
647		} else if err != nil {
648			return err
649		}
650
651		values := make([]*structpb.Value, len(row))
652		for i, x := range row {
653			v, err := spannerValueFromValue(x)
654			if err != nil {
655				return err
656			}
657			values[i] = v
658		}
659
660		prs := &spannerpb.PartialResultSet{
661			Metadata: rsm,
662			Values:   values,
663		}
664		if err := send(prs); err != nil {
665			return err
666		}
667
668		// ResultSetMetadata is only set for the first PartialResultSet.
669		rsm = nil
670	}
671
672	return nil
673}
674
675func (s *server) buildResultSetMetadata(ri rowIter) (*spannerpb.ResultSetMetadata, error) {
676	// Build the result set metadata.
677	rsm := &spannerpb.ResultSetMetadata{
678		RowType: &spannerpb.StructType{},
679		// TODO: transaction info?
680	}
681	for _, ci := range ri.Cols() {
682		st, err := spannerTypeFromType(ci.Type)
683		if err != nil {
684			return nil, err
685		}
686		rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{
687			Name: string(ci.Name),
688			Type: st,
689		})
690	}
691	return rsm, nil
692}
693
694func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
695	//s.logf("BeginTransaction(%v)", req)
696
697	s.mu.Lock()
698	sess, ok := s.sessions[req.Session]
699	s.mu.Unlock()
700	if !ok {
701		// TODO: what error does the real Spanner return?
702		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Session)
703	}
704
705	id := genRandomTransaction()
706	tx := s.db.NewTransaction()
707
708	sess.mu.Lock()
709	sess.lastUse = time.Now()
710	sess.transactions[id] = tx
711	sess.mu.Unlock()
712
713	tr := &spannerpb.Transaction{Id: []byte(id)}
714
715	if req.GetOptions().GetReadOnly().GetReturnReadTimestamp() {
716		// Return the last commit timestamp.
717		// This isn't wholly accurate, but may be good enough for simple use cases.
718		tr.ReadTimestamp = timestampProto(s.db.LastCommitTimestamp())
719	}
720
721	return tr, nil
722}
723
724func (s *server) Commit(ctx context.Context, req *spannerpb.CommitRequest) (resp *spannerpb.CommitResponse, err error) {
725	//s.logf("Commit(%q, %q)", req.Session, req.Transaction)
726
727	obj, ok := req.Transaction.(*spannerpb.CommitRequest_TransactionId)
728	if !ok {
729		return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction)
730	}
731	tid := string(obj.TransactionId)
732
733	tx, err := s.popTx(req.Session, tid)
734	if err != nil {
735		return nil, err
736	}
737	defer func() {
738		if err != nil {
739			tx.Rollback()
740		}
741	}()
742	tx.Start()
743
744	for _, m := range req.Mutations {
745		switch op := m.Operation.(type) {
746		default:
747			return nil, fmt.Errorf("unsupported mutation operation type %T", op)
748		case *spannerpb.Mutation_Insert:
749			ins := op.Insert
750			err := s.db.Insert(tx, spansql.ID(ins.Table), idList(ins.Columns), ins.Values)
751			if err != nil {
752				return nil, err
753			}
754		case *spannerpb.Mutation_Update:
755			up := op.Update
756			err := s.db.Update(tx, spansql.ID(up.Table), idList(up.Columns), up.Values)
757			if err != nil {
758				return nil, err
759			}
760		case *spannerpb.Mutation_InsertOrUpdate:
761			iou := op.InsertOrUpdate
762			err := s.db.InsertOrUpdate(tx, spansql.ID(iou.Table), idList(iou.Columns), iou.Values)
763			if err != nil {
764				return nil, err
765			}
766		case *spannerpb.Mutation_Delete_:
767			del := op.Delete
768			ks := del.KeySet
769
770			err := s.db.Delete(tx, spansql.ID(del.Table), ks.Keys, makeKeyRangeList(ks.Ranges), ks.All)
771			if err != nil {
772				return nil, err
773			}
774		}
775
776	}
777
778	ts, err := tx.Commit()
779	if err != nil {
780		return nil, err
781	}
782
783	return &spannerpb.CommitResponse{
784		CommitTimestamp: timestampProto(ts),
785	}, nil
786}
787
788func (s *server) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
789	s.logf("Rollback(%v)", req)
790
791	tx, err := s.popTx(req.Session, string(req.TransactionId))
792	if err != nil {
793		return nil, err
794	}
795
796	tx.Rollback()
797
798	return &emptypb.Empty{}, nil
799}
800
801// TODO: PartitionQuery, PartitionRead
802
803func parseQueryParams(p *structpb.Struct, types map[string]*spannerpb.Type) (queryParams, error) {
804	params := make(queryParams)
805	for k, v := range p.GetFields() {
806		p, err := parseQueryParam(v, types[k])
807		if err != nil {
808			return nil, err
809		}
810		params[k] = p
811	}
812	return params, nil
813}
814
815func parseQueryParam(v *structpb.Value, typ *spannerpb.Type) (queryParam, error) {
816	// TODO: Use valForType and typeFromSpannerType more comprehensively here?
817	// They are only used for StringValue vs, since that's what mostly needs parsing.
818
819	rawv := v
820	switch v := v.Kind.(type) {
821	default:
822		return queryParam{}, fmt.Errorf("unsupported well-known type value kind %T", v)
823	case *structpb.Value_NullValue:
824		return queryParam{Value: nil}, nil // TODO: set a type?
825	case *structpb.Value_BoolValue:
826		return queryParam{Value: v.BoolValue, Type: boolType}, nil
827	case *structpb.Value_NumberValue:
828		return queryParam{Value: v.NumberValue, Type: float64Type}, nil
829	case *structpb.Value_StringValue:
830		t, err := typeFromSpannerType(typ)
831		if err != nil {
832			return queryParam{}, err
833		}
834		val, err := valForType(rawv, t)
835		if err != nil {
836			return queryParam{}, err
837		}
838		return queryParam{Value: val, Type: t}, nil
839	case *structpb.Value_ListValue:
840		var list []interface{}
841		for _, elem := range v.ListValue.Values {
842			// TODO: Change the type parameter passed through? We only look at the code.
843			p, err := parseQueryParam(elem, typ)
844			if err != nil {
845				return queryParam{}, err
846			}
847			list = append(list, p.Value)
848		}
849		t, err := typeFromSpannerType(typ)
850		if err != nil {
851			return queryParam{}, err
852		}
853		return queryParam{Value: list, Type: t}, nil
854	}
855}
856
857func typeFromSpannerType(st *spannerpb.Type) (spansql.Type, error) {
858	switch st.Code {
859	default:
860		return spansql.Type{}, fmt.Errorf("unhandled spanner type code %v", st.Code)
861	case spannerpb.TypeCode_BOOL:
862		return spansql.Type{Base: spansql.Bool}, nil
863	case spannerpb.TypeCode_INT64:
864		return spansql.Type{Base: spansql.Int64}, nil
865	case spannerpb.TypeCode_FLOAT64:
866		return spansql.Type{Base: spansql.Float64}, nil
867	case spannerpb.TypeCode_TIMESTAMP:
868		return spansql.Type{Base: spansql.Timestamp}, nil
869	case spannerpb.TypeCode_DATE:
870		return spansql.Type{Base: spansql.Date}, nil
871	case spannerpb.TypeCode_STRING:
872		return spansql.Type{Base: spansql.String}, nil // no len
873	case spannerpb.TypeCode_BYTES:
874		return spansql.Type{Base: spansql.Bytes}, nil // no len
875	case spannerpb.TypeCode_ARRAY:
876		typ, err := typeFromSpannerType(st.ArrayElementType)
877		if err != nil {
878			return spansql.Type{}, err
879		}
880		typ.Array = true
881		return typ, nil
882	}
883}
884
885func spannerTypeFromType(typ spansql.Type) (*spannerpb.Type, error) {
886	var code spannerpb.TypeCode
887	switch typ.Base {
888	default:
889		return nil, fmt.Errorf("unhandled base type %d", typ.Base)
890	case spansql.Bool:
891		code = spannerpb.TypeCode_BOOL
892	case spansql.Int64:
893		code = spannerpb.TypeCode_INT64
894	case spansql.Float64:
895		code = spannerpb.TypeCode_FLOAT64
896	case spansql.String:
897		code = spannerpb.TypeCode_STRING
898	case spansql.Bytes:
899		code = spannerpb.TypeCode_BYTES
900	case spansql.Date:
901		code = spannerpb.TypeCode_DATE
902	case spansql.Timestamp:
903		code = spannerpb.TypeCode_TIMESTAMP
904	}
905	st := &spannerpb.Type{Code: code}
906	if typ.Array {
907		st = &spannerpb.Type{
908			Code:             spannerpb.TypeCode_ARRAY,
909			ArrayElementType: st,
910		}
911	}
912	return st, nil
913}
914
915func spannerValueFromValue(x interface{}) (*structpb.Value, error) {
916	switch x := x.(type) {
917	default:
918		return nil, fmt.Errorf("unhandled database value type %T", x)
919	case bool:
920		return &structpb.Value{Kind: &structpb.Value_BoolValue{x}}, nil
921	case int64:
922		// The Spanner int64 is actually a decimal string.
923		s := strconv.FormatInt(x, 10)
924		return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil
925	case float64:
926		return &structpb.Value{Kind: &structpb.Value_NumberValue{x}}, nil
927	case string:
928		return &structpb.Value{Kind: &structpb.Value_StringValue{x}}, nil
929	case []byte:
930		return &structpb.Value{Kind: &structpb.Value_StringValue{base64.StdEncoding.EncodeToString(x)}}, nil
931	case civil.Date:
932		// RFC 3339 date format.
933		return &structpb.Value{Kind: &structpb.Value_StringValue{x.String()}}, nil
934	case time.Time:
935		// RFC 3339 timestamp format with zone Z.
936		s := x.Format("2006-01-02T15:04:05.999999999Z")
937		return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil
938	case nil:
939		return &structpb.Value{Kind: &structpb.Value_NullValue{}}, nil
940	case []interface{}:
941		var vs []*structpb.Value
942		for _, elem := range x {
943			v, err := spannerValueFromValue(elem)
944			if err != nil {
945				return nil, err
946			}
947			vs = append(vs, v)
948		}
949		return &structpb.Value{Kind: &structpb.Value_ListValue{
950			&structpb.ListValue{Values: vs},
951		}}, nil
952	}
953}
954
955func makeKeyRangeList(ranges []*spannerpb.KeyRange) keyRangeList {
956	var krl keyRangeList
957	for _, r := range ranges {
958		krl = append(krl, makeKeyRange(r))
959	}
960	return krl
961}
962
963func makeKeyRange(r *spannerpb.KeyRange) *keyRange {
964	var kr keyRange
965	switch s := r.StartKeyType.(type) {
966	case *spannerpb.KeyRange_StartClosed:
967		kr.start = s.StartClosed
968		kr.startClosed = true
969	case *spannerpb.KeyRange_StartOpen:
970		kr.start = s.StartOpen
971	}
972	switch e := r.EndKeyType.(type) {
973	case *spannerpb.KeyRange_EndClosed:
974		kr.end = e.EndClosed
975		kr.endClosed = true
976	case *spannerpb.KeyRange_EndOpen:
977		kr.end = e.EndOpen
978	}
979	return &kr
980}
981
982func idList(ss []string) (ids []spansql.ID) {
983	for _, s := range ss {
984		ids = append(ids, spansql.ID(s))
985	}
986	return
987}
988