1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17/*
18Package spannertest contains test helpers for working with Cloud Spanner.
19
20This package is EXPERIMENTAL, and is lacking many features. See the README.md
21file in this directory for more details.
22
23In-memory fake
24
25This package has an in-memory fake implementation of spanner. To use it,
26create a Server, and then connect to it with no security:
27	srv, err := spannertest.NewServer("localhost:0")
28	...
29	conn, err := grpc.DialContext(ctx, srv.Addr, grpc.WithInsecure())
30	...
31	client, err := spanner.NewClient(ctx, db, option.WithGRPCConn(conn))
32	...
33
34Alternatively, create a Server, then set the SPANNER_EMULATOR_HOST environment
35variable and use the regular spanner.NewClient:
36	srv, err := spannertest.NewServer("localhost:0")
37	...
38	os.Setenv("SPANNER_EMULATOR_HOST", srv.Addr)
39	client, err := spanner.NewClient(ctx, db)
40	...
41
42The same server also supports database admin operations for use with
43the cloud.google.com/go/spanner/admin/database/apiv1 package. This only
44simulates the existence of a single database; its name is ignored.
45*/
46package spannertest
47
48import (
49	"context"
50	"encoding/base64"
51	"fmt"
52	"io"
53	"log"
54	"math/rand"
55	"net"
56	"strconv"
57	"sync"
58	"time"
59
60	"github.com/golang/protobuf/proto"
61	"github.com/golang/protobuf/ptypes"
62	"google.golang.org/grpc"
63	"google.golang.org/grpc/codes"
64	"google.golang.org/grpc/status"
65
66	anypb "github.com/golang/protobuf/ptypes/any"
67	emptypb "github.com/golang/protobuf/ptypes/empty"
68	structpb "github.com/golang/protobuf/ptypes/struct"
69	timestamppb "github.com/golang/protobuf/ptypes/timestamp"
70	lropb "google.golang.org/genproto/googleapis/longrunning"
71	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
72	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
73
74	"cloud.google.com/go/civil"
75	"cloud.google.com/go/spanner/spansql"
76)
77
78// Server is an in-memory Cloud Spanner fake.
79// It is unauthenticated, non-performant, and only a rough approximation.
80type Server struct {
81	Addr string
82
83	l   net.Listener
84	srv *grpc.Server
85	s   *server
86}
87
88// server is the real implementation of the fake.
89// It is a separate and unexported type so the API won't be cluttered with
90// methods that are only relevant to the fake's implementation.
91type server struct {
92	logf Logger
93
94	db    database
95	start time.Time
96
97	mu       sync.Mutex
98	sessions map[string]*session
99	lros     map[string]*lro
100
101	// Any unimplemented methods will cause a panic.
102	// TODO: Switch to Unimplemented at some point? spannerpb would need regenerating.
103	adminpb.DatabaseAdminServer
104	spannerpb.SpannerServer
105	lropb.OperationsServer
106}
107
108type session struct {
109	name     string
110	creation time.Time
111
112	// This context tracks the lifetime of this session.
113	// It is canceled in DeleteSession.
114	ctx    context.Context
115	cancel func()
116
117	mu           sync.Mutex
118	lastUse      time.Time
119	transactions map[string]*transaction
120}
121
122func (s *session) Proto() *spannerpb.Session {
123	s.mu.Lock()
124	defer s.mu.Unlock()
125	m := &spannerpb.Session{
126		Name:                   s.name,
127		CreateTime:             timestampProto(s.creation),
128		ApproximateLastUseTime: timestampProto(s.lastUse),
129	}
130	return m
131}
132
133// timestampProto returns a valid timestamp.Timestamp,
134// or nil if the given time is zero or isn't representable.
135func timestampProto(t time.Time) *timestamppb.Timestamp {
136	if t.IsZero() {
137		return nil
138	}
139	ts, err := ptypes.TimestampProto(t)
140	if err != nil {
141		return nil
142	}
143	return ts
144}
145
146// lro represents a Long-Running Operation, generally a schema change.
147type lro struct {
148	mu      sync.Mutex
149	state   *lropb.Operation
150	waiters bool // Whether anyone appears to be waiting.
151}
152
153func (l *lro) State() *lropb.Operation {
154	l.mu.Lock()
155	defer l.mu.Unlock()
156	return proto.Clone(l.state).(*lropb.Operation)
157}
158
// Logger is a printf-style logging function.
// Both log.Printf and testing.T.Logf have this signature.
type Logger func(format string, v ...interface{})
162
163// NewServer creates a new Server.
164// The Server will be listening for gRPC connections, without TLS, on the provided TCP address.
165// The resolved address is available in the Addr field.
166func NewServer(laddr string) (*Server, error) {
167	l, err := net.Listen("tcp", laddr)
168	if err != nil {
169		return nil, err
170	}
171
172	s := &Server{
173		Addr: l.Addr().String(),
174		l:    l,
175		srv:  grpc.NewServer(),
176		s: &server{
177			logf: func(format string, args ...interface{}) {
178				log.Printf("spannertest.inmem: "+format, args...)
179			},
180			start:    time.Now(),
181			sessions: make(map[string]*session),
182			lros:     make(map[string]*lro),
183		},
184	}
185	adminpb.RegisterDatabaseAdminServer(s.srv, s.s)
186	spannerpb.RegisterSpannerServer(s.srv, s.s)
187	lropb.RegisterOperationsServer(s.srv, s.s)
188
189	go s.srv.Serve(s.l)
190
191	return s, nil
192}
193
194// SetLogger sets a logger for the server.
195// You can use a *testing.T as this argument to collate extra information
196// from the execution of the server.
197func (s *Server) SetLogger(l Logger) { s.s.logf = l }
198
199// Close shuts down the server.
200func (s *Server) Close() {
201	s.srv.Stop()
202	s.l.Close()
203}
204
// genRandomSession returns a session ID of 8 lowercase hex digits drawn from
// math/rand (not cryptographically secure).
func genRandomSession() string {
	buf := make([]byte, 4)
	rand.Read(buf)
	return fmt.Sprintf("%x", buf)
}
210
// genRandomTransaction returns a transaction ID of the form "tx-" followed by
// 12 lowercase hex digits drawn from math/rand.
func genRandomTransaction() string {
	buf := make([]byte, 6)
	rand.Read(buf)
	return fmt.Sprintf("tx-%x", buf)
}
216
// genRandomOperation returns an operation ID of the form "op-" followed by
// 6 lowercase hex digits drawn from math/rand.
func genRandomOperation() string {
	buf := make([]byte, 3)
	rand.Read(buf)
	return fmt.Sprintf("op-%x", buf)
}
222
223func (s *server) GetOperation(ctx context.Context, req *lropb.GetOperationRequest) (*lropb.Operation, error) {
224	s.mu.Lock()
225	lro, ok := s.lros[req.Name]
226	s.mu.Unlock()
227	if !ok {
228		return nil, status.Errorf(codes.NotFound, "unknown LRO %q", req.Name)
229	}
230
231	// Someone is waiting on this LRO. Disable sleeping in its Run method.
232	lro.mu.Lock()
233	lro.waiters = true
234	lro.mu.Unlock()
235
236	return lro.State(), nil
237}
238
239func (s *server) GetDatabase(ctx context.Context, req *adminpb.GetDatabaseRequest) (*adminpb.Database, error) {
240	s.logf("GetDatabase(%q)", req.Name)
241
242	return &adminpb.Database{
243		Name:       req.Name,
244		State:      adminpb.Database_READY,
245		CreateTime: timestampProto(s.start),
246	}, nil
247}
248
249// UpdateDDL applies the given DDL to the server.
250//
251// This is a convenience method for tests that may assume an existing schema.
252// The more general approach is to dial this server using an admin client, and
253// use the UpdateDatabaseDdl RPC method.
254func (s *Server) UpdateDDL(ddl *spansql.DDL) error {
255	ctx := context.Background()
256	for _, stmt := range ddl.List {
257		if st := s.s.runOneDDL(ctx, stmt); st.Code() != codes.OK {
258			return st.Err()
259		}
260	}
261	return nil
262}
263
264func (s *server) UpdateDatabaseDdl(ctx context.Context, req *adminpb.UpdateDatabaseDdlRequest) (*lropb.Operation, error) {
265	// Parse all the DDL statements first.
266	var stmts []spansql.DDLStmt
267	for _, s := range req.Statements {
268		stmt, err := spansql.ParseDDLStmt(s)
269		if err != nil {
270			// TODO: check what code the real Spanner returns here.
271			return nil, status.Errorf(codes.InvalidArgument, "bad DDL statement %q: %v", s, err)
272		}
273		stmts = append(stmts, stmt)
274	}
275
276	// Nothing should be depending on the exact structure of this,
277	// but it is specified in google/spanner/admin/database/v1/spanner_database_admin.proto.
278	id := "projects/fake-proj/instances/fake-instance/databases/fake-db/operations/" + genRandomOperation()
279	lro := &lro{
280		state: &lropb.Operation{
281			Name: id,
282		},
283	}
284	s.mu.Lock()
285	s.lros[id] = lro
286	s.mu.Unlock()
287
288	go lro.Run(s, stmts)
289	return lro.State(), nil
290}
291
292func (l *lro) Run(s *server, stmts []spansql.DDLStmt) {
293	ctx := context.Background()
294
295	for _, stmt := range stmts {
296		l.mu.Lock()
297		waiters := l.waiters
298		l.mu.Unlock()
299		if !waiters {
300			// Simulate delayed DDL application, but only if nobody is waiting.
301			time.Sleep(100 * time.Millisecond)
302		}
303
304		if st := s.runOneDDL(ctx, stmt); st.Code() != codes.OK {
305			l.mu.Lock()
306			l.state.Done = true
307			l.state.Result = &lropb.Operation_Error{st.Proto()}
308			l.mu.Unlock()
309			return
310		}
311	}
312
313	l.mu.Lock()
314	l.state.Done = true
315	l.state.Result = &lropb.Operation_Response{&anypb.Any{}}
316	l.mu.Unlock()
317}
318
319func (s *server) runOneDDL(ctx context.Context, stmt spansql.DDLStmt) *status.Status {
320	return s.db.ApplyDDL(stmt)
321}
322
323func (s *server) GetDatabaseDdl(ctx context.Context, req *adminpb.GetDatabaseDdlRequest) (*adminpb.GetDatabaseDdlResponse, error) {
324	s.logf("GetDatabaseDdl(%q)", req.Database)
325
326	var resp adminpb.GetDatabaseDdlResponse
327	for _, stmt := range s.db.GetDDL() {
328		resp.Statements = append(resp.Statements, stmt.SQL())
329	}
330	return &resp, nil
331}
332
333func (s *server) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
334	//s.logf("CreateSession(%q)", req.Database)
335	return s.newSession(), nil
336}
337
338func (s *server) newSession() *spannerpb.Session {
339	id := genRandomSession()
340	now := time.Now()
341	sess := &session{
342		name:         id,
343		creation:     now,
344		lastUse:      now,
345		transactions: make(map[string]*transaction),
346	}
347	sess.ctx, sess.cancel = context.WithCancel(context.Background())
348
349	s.mu.Lock()
350	s.sessions[id] = sess
351	s.mu.Unlock()
352
353	return sess.Proto()
354}
355
356func (s *server) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
357	//s.logf("BatchCreateSessions(%q)", req.Database)
358
359	var sessions []*spannerpb.Session
360	for i := int32(0); i < req.GetSessionCount(); i++ {
361		sessions = append(sessions, s.newSession())
362	}
363
364	return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil
365}
366
367func (s *server) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
368	s.mu.Lock()
369	sess, ok := s.sessions[req.Name]
370	s.mu.Unlock()
371
372	if !ok {
373		// TODO: what error does the real Spanner return?
374		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name)
375	}
376
377	return sess.Proto(), nil
378}
379
380// TODO: ListSessions
381
382func (s *server) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
383	//s.logf("DeleteSession(%q)", req.Name)
384
385	s.mu.Lock()
386	sess, ok := s.sessions[req.Name]
387	delete(s.sessions, req.Name)
388	s.mu.Unlock()
389
390	if !ok {
391		// TODO: what error does the real Spanner return?
392		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name)
393	}
394
395	// Terminate any operations in this session.
396	sess.cancel()
397
398	return &emptypb.Empty{}, nil
399}
400
401// popTx returns an existing transaction, removing it from the session.
402// This is called when a transaction is finishing (Commit, Rollback).
403func (s *server) popTx(sessionID, tid string) (tx *transaction, err error) {
404	s.mu.Lock()
405	sess, ok := s.sessions[sessionID]
406	s.mu.Unlock()
407	if !ok {
408		// TODO: what error does the real Spanner return?
409		return nil, status.Errorf(codes.NotFound, "unknown session %q", sessionID)
410	}
411
412	sess.mu.Lock()
413	sess.lastUse = time.Now()
414	tx, ok = sess.transactions[tid]
415	if ok {
416		delete(sess.transactions, tid)
417	}
418	sess.mu.Unlock()
419	if !ok {
420		// TODO: what error does the real Spanner return?
421		return nil, status.Errorf(codes.NotFound, "unknown transaction ID %q", tid)
422	}
423	return tx, nil
424}
425
426// readTx returns a transaction for the given session and transaction selector.
427// It is used by read/query operations (ExecuteStreamingSql, StreamingRead).
428func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.TransactionSelector) (tx *transaction, cleanup func(), err error) {
429	s.mu.Lock()
430	sess, ok := s.sessions[session]
431	s.mu.Unlock()
432	if !ok {
433		// TODO: what error does the real Spanner return?
434		return nil, nil, status.Errorf(codes.NotFound, "unknown session %q", session)
435	}
436
437	sess.mu.Lock()
438	sess.lastUse = time.Now()
439	sess.mu.Unlock()
440
441	// Only give a read-only transaction regardless of whether the selector
442	// is requesting a read-write or read-only one, since this is in readTx
443	// and so shouldn't be mutating anyway.
444	singleUse := func() (*transaction, func(), error) {
445		tx := s.db.NewReadOnlyTransaction()
446		return tx, tx.Rollback, nil
447	}
448
449	if tsel.GetSelector() == nil {
450		return singleUse()
451	}
452
453	switch sel := tsel.Selector.(type) {
454	default:
455		return nil, nil, fmt.Errorf("TransactionSelector type %T not supported", sel)
456	case *spannerpb.TransactionSelector_SingleUse:
457		// Ignore options (e.g. timestamps).
458		switch mode := sel.SingleUse.Mode.(type) {
459		case *spannerpb.TransactionOptions_ReadOnly_:
460			return singleUse()
461		case *spannerpb.TransactionOptions_ReadWrite_:
462			return singleUse()
463		default:
464			return nil, nil, fmt.Errorf("single use transaction in mode %T not supported", mode)
465		}
466	case *spannerpb.TransactionSelector_Id:
467		sess.mu.Lock()
468		tx, ok := sess.transactions[string(sel.Id)]
469		sess.mu.Unlock()
470		if !ok {
471			return nil, nil, fmt.Errorf("no transaction with id %q", sel.Id)
472		}
473		return tx, func() {}, nil
474	}
475}
476
477func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
478	// Assume this is probably a DML statement. Queries tend to use ExecuteStreamingSql.
479	// TODO: Expand this to support more things.
480
481	obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id)
482	if !ok {
483		return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector)
484	}
485	tid := string(obj.Id)
486	_ = tid // TODO: lookup an existing transaction by ID.
487
488	stmt, err := spansql.ParseDMLStmt(req.Sql)
489	if err != nil {
490		return nil, status.Errorf(codes.InvalidArgument, "bad DML: %v", err)
491	}
492	params, err := parseQueryParams(req.GetParams(), req.ParamTypes)
493	if err != nil {
494		return nil, err
495	}
496
497	s.logf("Executing: %s", stmt.SQL())
498	if len(params) > 0 {
499		s.logf("        ▹ %v", params)
500	}
501
502	n, err := s.db.Execute(stmt, params)
503	if err != nil {
504		return nil, err
505	}
506	return &spannerpb.ResultSet{
507		Stats: &spannerpb.ResultSetStats{
508			RowCount: &spannerpb.ResultSetStats_RowCountExact{int64(n)},
509		},
510	}, nil
511}
512
513func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
514	tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction)
515	if err != nil {
516		return err
517	}
518	defer cleanup()
519
520	q, err := spansql.ParseQuery(req.Sql)
521	if err != nil {
522		// TODO: check what code the real Spanner returns here.
523		return status.Errorf(codes.InvalidArgument, "bad query: %v", err)
524	}
525
526	params, err := parseQueryParams(req.GetParams(), req.ParamTypes)
527	if err != nil {
528		return err
529	}
530
531	s.logf("Querying: %s", q.SQL())
532	if len(params) > 0 {
533		s.logf("        ▹ %v", params)
534	}
535
536	ri, err := s.db.Query(q, params)
537	if err != nil {
538		return err
539	}
540	return s.readStream(stream.Context(), tx, stream.Send, ri)
541}
542
543// TODO: Read
544
545func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
546	tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction)
547	if err != nil {
548		return err
549	}
550	defer cleanup()
551
552	// Bail out if various advanced features are being used.
553	if req.Index != "" {
554		// This is okay; we can still return results.
555		s.logf("Warning: index reads (%q) not supported", req.Index)
556	}
557	if len(req.ResumeToken) > 0 {
558		// This should only happen if we send resume_token ourselves.
559		return fmt.Errorf("read resumption not supported")
560	}
561	if len(req.PartitionToken) > 0 {
562		return fmt.Errorf("partition restrictions not supported")
563	}
564
565	var ri rowIter
566	if req.KeySet.All {
567		s.logf("Reading all from %s (cols: %v)", req.Table, req.Columns)
568		ri, err = s.db.ReadAll(spansql.ID(req.Table), idList(req.Columns), req.Limit)
569	} else {
570		s.logf("Reading rows from %d keys and %d ranges from %s (cols: %v)", len(req.KeySet.Keys), len(req.KeySet.Ranges), req.Table, req.Columns)
571		ri, err = s.db.Read(spansql.ID(req.Table), idList(req.Columns), req.KeySet.Keys, makeKeyRangeList(req.KeySet.Ranges), req.Limit)
572	}
573	if err != nil {
574		return err
575	}
576
577	// TODO: Figure out the right contexts to use here. There's the session one (sess.ctx),
578	// but also this specific RPC one (stream.Context()). Which takes precedence?
579	// They appear to be independent.
580
581	return s.readStream(stream.Context(), tx, stream.Send, ri)
582}
583
584func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error {
585	// Build the result set metadata.
586	rsm := &spannerpb.ResultSetMetadata{
587		RowType: &spannerpb.StructType{},
588		// TODO: transaction info?
589	}
590	for _, ci := range ri.Cols() {
591		st, err := spannerTypeFromType(ci.Type)
592		if err != nil {
593			return err
594		}
595		rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{
596			Name: string(ci.Name),
597			Type: st,
598		})
599	}
600
601	for {
602		row, err := ri.Next()
603		if err == io.EOF {
604			break
605		} else if err != nil {
606			return err
607		}
608
609		values := make([]*structpb.Value, len(row))
610		for i, x := range row {
611			v, err := spannerValueFromValue(x)
612			if err != nil {
613				return err
614			}
615			values[i] = v
616		}
617
618		prs := &spannerpb.PartialResultSet{
619			Metadata: rsm,
620			Values:   values,
621		}
622		if err := send(prs); err != nil {
623			return err
624		}
625
626		// ResultSetMetadata is only set for the first PartialResultSet.
627		rsm = nil
628	}
629
630	return nil
631}
632
633func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
634	//s.logf("BeginTransaction(%v)", req)
635
636	s.mu.Lock()
637	sess, ok := s.sessions[req.Session]
638	s.mu.Unlock()
639	if !ok {
640		// TODO: what error does the real Spanner return?
641		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Session)
642	}
643
644	id := genRandomTransaction()
645	tx := s.db.NewTransaction()
646
647	sess.mu.Lock()
648	sess.lastUse = time.Now()
649	sess.transactions[id] = tx
650	sess.mu.Unlock()
651
652	tr := &spannerpb.Transaction{Id: []byte(id)}
653
654	if req.GetOptions().GetReadOnly().GetReturnReadTimestamp() {
655		// Return the last commit timestamp.
656		// This isn't wholly accurate, but may be good enough for simple use cases.
657		tr.ReadTimestamp = timestampProto(s.db.LastCommitTimestamp())
658	}
659
660	return tr, nil
661}
662
663func (s *server) Commit(ctx context.Context, req *spannerpb.CommitRequest) (resp *spannerpb.CommitResponse, err error) {
664	//s.logf("Commit(%q, %q)", req.Session, req.Transaction)
665
666	obj, ok := req.Transaction.(*spannerpb.CommitRequest_TransactionId)
667	if !ok {
668		return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction)
669	}
670	tid := string(obj.TransactionId)
671
672	tx, err := s.popTx(req.Session, tid)
673	if err != nil {
674		return nil, err
675	}
676	defer func() {
677		if err != nil {
678			tx.Rollback()
679		}
680	}()
681	tx.Start()
682
683	for _, m := range req.Mutations {
684		switch op := m.Operation.(type) {
685		default:
686			return nil, fmt.Errorf("unsupported mutation operation type %T", op)
687		case *spannerpb.Mutation_Insert:
688			ins := op.Insert
689			err := s.db.Insert(tx, spansql.ID(ins.Table), idList(ins.Columns), ins.Values)
690			if err != nil {
691				return nil, err
692			}
693		case *spannerpb.Mutation_Update:
694			up := op.Update
695			err := s.db.Update(tx, spansql.ID(up.Table), idList(up.Columns), up.Values)
696			if err != nil {
697				return nil, err
698			}
699		case *spannerpb.Mutation_InsertOrUpdate:
700			iou := op.InsertOrUpdate
701			err := s.db.InsertOrUpdate(tx, spansql.ID(iou.Table), idList(iou.Columns), iou.Values)
702			if err != nil {
703				return nil, err
704			}
705		case *spannerpb.Mutation_Delete_:
706			del := op.Delete
707			ks := del.KeySet
708
709			err := s.db.Delete(tx, spansql.ID(del.Table), ks.Keys, makeKeyRangeList(ks.Ranges), ks.All)
710			if err != nil {
711				return nil, err
712			}
713		}
714
715	}
716
717	ts, err := tx.Commit()
718	if err != nil {
719		return nil, err
720	}
721
722	return &spannerpb.CommitResponse{
723		CommitTimestamp: timestampProto(ts),
724	}, nil
725}
726
727func (s *server) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
728	s.logf("Rollback(%v)", req)
729
730	tx, err := s.popTx(req.Session, string(req.TransactionId))
731	if err != nil {
732		return nil, err
733	}
734
735	tx.Rollback()
736
737	return &emptypb.Empty{}, nil
738}
739
740// TODO: PartitionQuery, PartitionRead
741
742func parseQueryParams(p *structpb.Struct, types map[string]*spannerpb.Type) (queryParams, error) {
743	params := make(queryParams)
744	for k, v := range p.GetFields() {
745		p, err := parseQueryParam(v, types[k])
746		if err != nil {
747			return nil, err
748		}
749		params[k] = p
750	}
751	return params, nil
752}
753
754func parseQueryParam(v *structpb.Value, typ *spannerpb.Type) (queryParam, error) {
755	// TODO: Use valForType and typeFromSpannerType more comprehensively here?
756	// They are only used for StringValue vs, since that's what mostly needs parsing.
757
758	rawv := v
759	switch v := v.Kind.(type) {
760	default:
761		return queryParam{}, fmt.Errorf("unsupported well-known type value kind %T", v)
762	case *structpb.Value_NullValue:
763		return queryParam{Value: nil}, nil // TODO: set a type?
764	case *structpb.Value_BoolValue:
765		return queryParam{Value: v.BoolValue, Type: boolType}, nil
766	case *structpb.Value_NumberValue:
767		return queryParam{Value: v.NumberValue, Type: float64Type}, nil
768	case *structpb.Value_StringValue:
769		t, err := typeFromSpannerType(typ)
770		if err != nil {
771			return queryParam{}, err
772		}
773		val, err := valForType(rawv, t)
774		if err != nil {
775			return queryParam{}, err
776		}
777		return queryParam{Value: val, Type: t}, nil
778	case *structpb.Value_ListValue:
779		var list []interface{}
780		for _, elem := range v.ListValue.Values {
781			// TODO: Change the type parameter passed through? We only look at the code.
782			p, err := parseQueryParam(elem, typ)
783			if err != nil {
784				return queryParam{}, err
785			}
786			list = append(list, p.Value)
787		}
788		t, err := typeFromSpannerType(typ)
789		if err != nil {
790			return queryParam{}, err
791		}
792		return queryParam{Value: list, Type: t}, nil
793	}
794}
795
796func typeFromSpannerType(st *spannerpb.Type) (spansql.Type, error) {
797	switch st.Code {
798	default:
799		return spansql.Type{}, fmt.Errorf("unhandled spanner type code %v", st.Code)
800	case spannerpb.TypeCode_BOOL:
801		return spansql.Type{Base: spansql.Bool}, nil
802	case spannerpb.TypeCode_INT64:
803		return spansql.Type{Base: spansql.Int64}, nil
804	case spannerpb.TypeCode_FLOAT64:
805		return spansql.Type{Base: spansql.Float64}, nil
806	case spannerpb.TypeCode_TIMESTAMP:
807		return spansql.Type{Base: spansql.Timestamp}, nil
808	case spannerpb.TypeCode_DATE:
809		return spansql.Type{Base: spansql.Date}, nil
810	case spannerpb.TypeCode_STRING:
811		return spansql.Type{Base: spansql.String}, nil // no len
812	case spannerpb.TypeCode_BYTES:
813		return spansql.Type{Base: spansql.Bytes}, nil // no len
814	case spannerpb.TypeCode_ARRAY:
815		typ, err := typeFromSpannerType(st.ArrayElementType)
816		if err != nil {
817			return spansql.Type{}, err
818		}
819		typ.Array = true
820		return typ, nil
821	}
822}
823
824func spannerTypeFromType(typ spansql.Type) (*spannerpb.Type, error) {
825	var code spannerpb.TypeCode
826	switch typ.Base {
827	default:
828		return nil, fmt.Errorf("unhandled base type %d", typ.Base)
829	case spansql.Bool:
830		code = spannerpb.TypeCode_BOOL
831	case spansql.Int64:
832		code = spannerpb.TypeCode_INT64
833	case spansql.Float64:
834		code = spannerpb.TypeCode_FLOAT64
835	case spansql.String:
836		code = spannerpb.TypeCode_STRING
837	case spansql.Bytes:
838		code = spannerpb.TypeCode_BYTES
839	case spansql.Date:
840		code = spannerpb.TypeCode_DATE
841	case spansql.Timestamp:
842		code = spannerpb.TypeCode_TIMESTAMP
843	}
844	st := &spannerpb.Type{Code: code}
845	if typ.Array {
846		st = &spannerpb.Type{
847			Code:             spannerpb.TypeCode_ARRAY,
848			ArrayElementType: st,
849		}
850	}
851	return st, nil
852}
853
854func spannerValueFromValue(x interface{}) (*structpb.Value, error) {
855	switch x := x.(type) {
856	default:
857		return nil, fmt.Errorf("unhandled database value type %T", x)
858	case bool:
859		return &structpb.Value{Kind: &structpb.Value_BoolValue{x}}, nil
860	case int64:
861		// The Spanner int64 is actually a decimal string.
862		s := strconv.FormatInt(x, 10)
863		return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil
864	case float64:
865		return &structpb.Value{Kind: &structpb.Value_NumberValue{x}}, nil
866	case string:
867		return &structpb.Value{Kind: &structpb.Value_StringValue{x}}, nil
868	case []byte:
869		return &structpb.Value{Kind: &structpb.Value_StringValue{base64.StdEncoding.EncodeToString(x)}}, nil
870	case civil.Date:
871		// RFC 3339 date format.
872		return &structpb.Value{Kind: &structpb.Value_StringValue{x.String()}}, nil
873	case time.Time:
874		// RFC 3339 timestamp format with zone Z.
875		s := x.Format("2006-01-02T15:04:05.999999999Z")
876		return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil
877	case nil:
878		return &structpb.Value{Kind: &structpb.Value_NullValue{}}, nil
879	case []interface{}:
880		var vs []*structpb.Value
881		for _, elem := range x {
882			v, err := spannerValueFromValue(elem)
883			if err != nil {
884				return nil, err
885			}
886			vs = append(vs, v)
887		}
888		return &structpb.Value{Kind: &structpb.Value_ListValue{
889			&structpb.ListValue{Values: vs},
890		}}, nil
891	}
892}
893
894func makeKeyRangeList(ranges []*spannerpb.KeyRange) keyRangeList {
895	var krl keyRangeList
896	for _, r := range ranges {
897		krl = append(krl, makeKeyRange(r))
898	}
899	return krl
900}
901
902func makeKeyRange(r *spannerpb.KeyRange) *keyRange {
903	var kr keyRange
904	switch s := r.StartKeyType.(type) {
905	case *spannerpb.KeyRange_StartClosed:
906		kr.start = s.StartClosed
907		kr.startClosed = true
908	case *spannerpb.KeyRange_StartOpen:
909		kr.start = s.StartOpen
910	}
911	switch e := r.EndKeyType.(type) {
912	case *spannerpb.KeyRange_EndClosed:
913		kr.end = e.EndClosed
914		kr.endClosed = true
915	case *spannerpb.KeyRange_EndOpen:
916		kr.end = e.EndOpen
917	}
918	return &kr
919}
920
921func idList(ss []string) (ids []spansql.ID) {
922	for _, s := range ss {
923		ids = append(ids, spansql.ID(s))
924	}
925	return
926}
927