1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17/*
18Package spannertest contains test helpers for working with Cloud Spanner.
19
20This package is EXPERIMENTAL, and is lacking many features. See the README.md
21file in this directory for more details.
22
23In-memory fake
24
25This package has an in-memory fake implementation of spanner. To use it,
26create a Server, and then connect to it with no security:
27	srv, err := spannertest.NewServer("localhost:0")
28	...
29	conn, err := grpc.DialContext(ctx, srv.Addr, grpc.WithInsecure())
30	...
31	client, err := spanner.NewClient(ctx, db, option.WithGRPCConn(conn))
32	...
33
34Alternatively, create a Server, then set the SPANNER_EMULATOR_HOST environment
35variable and use the regular spanner.NewClient:
36	srv, err := spannertest.NewServer("localhost:0")
37	...
38	os.Setenv("SPANNER_EMULATOR_HOST", srv.Addr)
39	client, err := spanner.NewClient(ctx, db)
40	...
41
42The same server also supports database admin operations for use with
43the cloud.google.com/go/spanner/admin/database/apiv1 package.
44*/
45package spannertest
46
47import (
48	"context"
49	"encoding/base64"
50	"fmt"
51	"io"
52	"log"
53	"math/rand"
54	"net"
55	"strconv"
56	"sync"
57	"time"
58
59	"github.com/golang/protobuf/proto"
60	"github.com/golang/protobuf/ptypes"
61	"google.golang.org/grpc"
62	"google.golang.org/grpc/codes"
63	"google.golang.org/grpc/status"
64
65	anypb "github.com/golang/protobuf/ptypes/any"
66	emptypb "github.com/golang/protobuf/ptypes/empty"
67	structpb "github.com/golang/protobuf/ptypes/struct"
68	timestamppb "github.com/golang/protobuf/ptypes/timestamp"
69	lropb "google.golang.org/genproto/googleapis/longrunning"
70	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
71	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
72
73	"cloud.google.com/go/spanner/spansql"
74)
75
// Server is an in-memory Cloud Spanner fake.
// It is unauthenticated, non-performant, and only a rough approximation.
//
// Create one with NewServer and shut it down with Close.
type Server struct {
	// Addr is the resolved TCP address the fake is listening on, suitable
	// for dialing or for use as SPANNER_EMULATOR_HOST.
	Addr string

	// l is the listener accepting connections, srv is the gRPC server serving
	// on it, and s implements the services registered with srv.
	l   net.Listener
	srv *grpc.Server
	s   *server
}
85
// server is the real implementation of the fake.
// It is a separate and unexported type so the API won't be cluttered with
// methods that are only relevant to the fake's implementation.
type server struct {
	// logf receives diagnostic output; Server.SetLogger replaces it.
	logf Logger

	// db is the single in-memory database that every RPC operates on.
	db database

	// mu guards sessions (keyed by session name) and lros (keyed by
	// operation name).
	mu       sync.Mutex
	sessions map[string]*session
	lros     map[string]*lro

	// Any unimplemented methods will cause a panic.
	// (These embedded interfaces are never assigned, so calling a method that
	// server does not define itself dereferences a nil interface value.)
	// TODO: Switch to Unimplemented at some point? spannerpb would need regenerating.
	adminpb.DatabaseAdminServer
	spannerpb.SpannerServer
	lropb.OperationsServer
}
104
// session is the server-side state of one Spanner session.
type session struct {
	// name is the session ID (also its key in server.sessions) and creation
	// is when it was created; newSession sets both.
	name     string
	creation time.Time

	// This context tracks the lifetime of this session.
	// It is canceled in DeleteSession.
	ctx    context.Context
	cancel func()

	// mu guards lastUse, which popTx, readTx and BeginTransaction refresh,
	// and transactions, the session's begun transactions keyed by ID.
	mu           sync.Mutex
	lastUse      time.Time
	transactions map[string]*transaction
}
118
119func (s *session) Proto() *spannerpb.Session {
120	s.mu.Lock()
121	defer s.mu.Unlock()
122	m := &spannerpb.Session{
123		Name:                   s.name,
124		CreateTime:             timestampProto(s.creation),
125		ApproximateLastUseTime: timestampProto(s.lastUse),
126	}
127	return m
128}
129
130// timestampProto returns a valid timestamp.Timestamp,
131// or nil if the given time is zero or isn't representable.
132func timestampProto(t time.Time) *timestamppb.Timestamp {
133	if t.IsZero() {
134		return nil
135	}
136	ts, err := ptypes.TimestampProto(t)
137	if err != nil {
138		return nil
139	}
140	return ts
141}
142
// lro represents a Long-Running Operation, generally a schema change.
type lro struct {
	// mu guards state (the Operation proto reported to clients; Run sets
	// Done and Result once finished) and waiters.
	mu      sync.Mutex
	state   *lropb.Operation
	waiters bool // Whether anyone appears to be waiting.
}
149
150func (l *lro) State() *lropb.Operation {
151	l.mu.Lock()
152	defer l.mu.Unlock()
153	return proto.Clone(l.state).(*lropb.Operation)
154}
155
// Logger is something that can be used for logging.
// It is matched by log.Printf and testing.T.Logf.
// format and args follow the fmt.Printf conventions.
type Logger func(format string, args ...interface{})
159
160// NewServer creates a new Server.
161// The Server will be listening for gRPC connections, without TLS, on the provided TCP address.
162// The resolved address is available in the Addr field.
163func NewServer(laddr string) (*Server, error) {
164	l, err := net.Listen("tcp", laddr)
165	if err != nil {
166		return nil, err
167	}
168
169	s := &Server{
170		Addr: l.Addr().String(),
171		l:    l,
172		srv:  grpc.NewServer(),
173		s: &server{
174			logf: func(format string, args ...interface{}) {
175				log.Printf("spannertest.inmem: "+format, args...)
176			},
177			sessions: make(map[string]*session),
178			lros:     make(map[string]*lro),
179		},
180	}
181	adminpb.RegisterDatabaseAdminServer(s.srv, s.s)
182	spannerpb.RegisterSpannerServer(s.srv, s.s)
183	lropb.RegisterOperationsServer(s.srv, s.s)
184
185	go s.srv.Serve(s.l)
186
187	return s, nil
188}
189
190// SetLogger sets a logger for the server.
191// You can use a *testing.T as this argument to collate extra information
192// from the execution of the server.
193func (s *Server) SetLogger(l Logger) { s.s.logf = l }
194
195// Close shuts down the server.
196func (s *Server) Close() {
197	s.srv.Stop()
198	s.l.Close()
199}
200
// genRandomSession returns a fresh session ID made of 4 random bytes
// rendered as 8 lowercase hex digits. The bytes come from math/rand, so
// uniqueness is not guaranteed.
func genRandomSession() string {
	buf := make([]byte, 4)
	rand.Read(buf)
	return fmt.Sprintf("%x", buf)
}
206
// genRandomTransaction returns a fresh transaction ID of the form
// "tx-" followed by 12 lowercase hex digits (6 random bytes from math/rand).
func genRandomTransaction() string {
	buf := make([]byte, 6)
	rand.Read(buf)
	return fmt.Sprintf("tx-%x", buf)
}
212
// genRandomOperation returns a fresh operation ID suffix of the form
// "op-" followed by 6 lowercase hex digits (3 random bytes from math/rand).
func genRandomOperation() string {
	buf := make([]byte, 3)
	rand.Read(buf)
	return fmt.Sprintf("op-%x", buf)
}
218
219func (s *server) GetOperation(ctx context.Context, req *lropb.GetOperationRequest) (*lropb.Operation, error) {
220	s.mu.Lock()
221	lro, ok := s.lros[req.Name]
222	s.mu.Unlock()
223	if !ok {
224		return nil, status.Errorf(codes.NotFound, "unknown LRO %q", req.Name)
225	}
226
227	// Someone is waiting on this LRO. Disable sleeping in its Run method.
228	lro.mu.Lock()
229	lro.waiters = true
230	lro.mu.Unlock()
231
232	return lro.State(), nil
233}
234
235// UpdateDDL applies the given DDL to the server.
236//
237// This is a convenience method for tests that may assume an existing schema.
238// The more general approach is to dial this server using an admin client, and
239// use the UpdateDatabaseDdl RPC method.
240func (s *Server) UpdateDDL(ddl *spansql.DDL) error {
241	ctx := context.Background()
242	for _, stmt := range ddl.List {
243		if st := s.s.runOneDDL(ctx, stmt); st.Code() != codes.OK {
244			return st.Err()
245		}
246	}
247	return nil
248}
249
250func (s *server) UpdateDatabaseDdl(ctx context.Context, req *adminpb.UpdateDatabaseDdlRequest) (*lropb.Operation, error) {
251	// Parse all the DDL statements first.
252	var stmts []spansql.DDLStmt
253	for _, s := range req.Statements {
254		stmt, err := spansql.ParseDDLStmt(s)
255		if err != nil {
256			// TODO: check what code the real Spanner returns here.
257			return nil, status.Errorf(codes.InvalidArgument, "bad DDL statement %q: %v", s, err)
258		}
259		stmts = append(stmts, stmt)
260	}
261
262	// Nothing should be depending on the exact structure of this,
263	// but it is specified in google/spanner/admin/database/v1/spanner_database_admin.proto.
264	id := "projects/fake-proj/instances/fake-instance/databases/fake-db/operations/" + genRandomOperation()
265	lro := &lro{
266		state: &lropb.Operation{
267			Name: id,
268		},
269	}
270	s.mu.Lock()
271	s.lros[id] = lro
272	s.mu.Unlock()
273
274	go lro.Run(s, stmts)
275	return lro.State(), nil
276}
277
278func (l *lro) Run(s *server, stmts []spansql.DDLStmt) {
279	ctx := context.Background()
280
281	for _, stmt := range stmts {
282		l.mu.Lock()
283		waiters := l.waiters
284		l.mu.Unlock()
285		if !waiters {
286			// Simulate delayed DDL application, but only if nobody is waiting.
287			time.Sleep(100 * time.Millisecond)
288		}
289
290		if st := s.runOneDDL(ctx, stmt); st.Code() != codes.OK {
291			l.mu.Lock()
292			l.state.Done = true
293			l.state.Result = &lropb.Operation_Error{st.Proto()}
294			l.mu.Unlock()
295			return
296		}
297	}
298
299	l.mu.Lock()
300	l.state.Done = true
301	l.state.Result = &lropb.Operation_Response{&anypb.Any{}}
302	l.mu.Unlock()
303}
304
305func (s *server) runOneDDL(ctx context.Context, stmt spansql.DDLStmt) *status.Status {
306	return s.db.ApplyDDL(stmt)
307}
308
309func (s *server) GetDatabaseDdl(ctx context.Context, req *adminpb.GetDatabaseDdlRequest) (*adminpb.GetDatabaseDdlResponse, error) {
310	s.logf("GetDatabaseDdl(%q)", req.Database)
311
312	var resp adminpb.GetDatabaseDdlResponse
313	for _, stmt := range s.db.GetDDL() {
314		resp.Statements = append(resp.Statements, stmt.SQL())
315	}
316	return &resp, nil
317}
318
319func (s *server) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
320	//s.logf("CreateSession(%q)", req.Database)
321	return s.newSession(), nil
322}
323
324func (s *server) newSession() *spannerpb.Session {
325	id := genRandomSession()
326	now := time.Now()
327	sess := &session{
328		name:         id,
329		creation:     now,
330		lastUse:      now,
331		transactions: make(map[string]*transaction),
332	}
333	sess.ctx, sess.cancel = context.WithCancel(context.Background())
334
335	s.mu.Lock()
336	s.sessions[id] = sess
337	s.mu.Unlock()
338
339	return sess.Proto()
340}
341
342func (s *server) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
343	//s.logf("BatchCreateSessions(%q)", req.Database)
344
345	var sessions []*spannerpb.Session
346	for i := int32(0); i < req.GetSessionCount(); i++ {
347		sessions = append(sessions, s.newSession())
348	}
349
350	return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil
351}
352
353func (s *server) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
354	s.mu.Lock()
355	sess, ok := s.sessions[req.Name]
356	s.mu.Unlock()
357
358	if !ok {
359		// TODO: what error does the real Spanner return?
360		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name)
361	}
362
363	return sess.Proto(), nil
364}
365
366// TODO: ListSessions
367
368func (s *server) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
369	//s.logf("DeleteSession(%q)", req.Name)
370
371	s.mu.Lock()
372	sess, ok := s.sessions[req.Name]
373	delete(s.sessions, req.Name)
374	s.mu.Unlock()
375
376	if !ok {
377		// TODO: what error does the real Spanner return?
378		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name)
379	}
380
381	// Terminate any operations in this session.
382	sess.cancel()
383
384	return &emptypb.Empty{}, nil
385}
386
387// popTx returns an existing transaction, removing it from the session.
388// This is called when a transaction is finishing (Commit, Rollback).
389func (s *server) popTx(sessionID, tid string) (tx *transaction, err error) {
390	s.mu.Lock()
391	sess, ok := s.sessions[sessionID]
392	s.mu.Unlock()
393	if !ok {
394		// TODO: what error does the real Spanner return?
395		return nil, status.Errorf(codes.NotFound, "unknown session %q", sessionID)
396	}
397
398	sess.mu.Lock()
399	sess.lastUse = time.Now()
400	tx, ok = sess.transactions[tid]
401	if ok {
402		delete(sess.transactions, tid)
403	}
404	sess.mu.Unlock()
405	if !ok {
406		// TODO: what error does the real Spanner return?
407		return nil, status.Errorf(codes.NotFound, "unknown transaction ID %q", tid)
408	}
409	return tx, nil
410}
411
412// readTx returns a transaction for the given session and transaction selector.
413// It is used by read/query operations (ExecuteStreamingSql, StreamingRead).
414func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.TransactionSelector) (tx *transaction, cleanup func(), err error) {
415	s.mu.Lock()
416	sess, ok := s.sessions[session]
417	s.mu.Unlock()
418	if !ok {
419		// TODO: what error does the real Spanner return?
420		return nil, nil, status.Errorf(codes.NotFound, "unknown session %q", session)
421	}
422
423	sess.mu.Lock()
424	sess.lastUse = time.Now()
425	sess.mu.Unlock()
426
427	// Only give a read-only transaction regardless of whether the selector
428	// is requesting a read-write or read-only one, since this is in readTx
429	// and so shouldn't be mutating anyway.
430	singleUse := func() (*transaction, func(), error) {
431		tx := s.db.NewReadOnlyTransaction()
432		return tx, tx.Rollback, nil
433	}
434
435	if tsel.GetSelector() == nil {
436		return singleUse()
437	}
438
439	switch sel := tsel.Selector.(type) {
440	default:
441		return nil, nil, fmt.Errorf("TransactionSelector type %T not supported", sel)
442	case *spannerpb.TransactionSelector_SingleUse:
443		// Ignore options (e.g. timestamps).
444		switch mode := sel.SingleUse.Mode.(type) {
445		case *spannerpb.TransactionOptions_ReadOnly_:
446			return singleUse()
447		case *spannerpb.TransactionOptions_ReadWrite_:
448			return singleUse()
449		default:
450			return nil, nil, fmt.Errorf("single use transaction in mode %T not supported", mode)
451		}
452	case *spannerpb.TransactionSelector_Id:
453		sess.mu.Lock()
454		tx, ok := sess.transactions[string(sel.Id)]
455		sess.mu.Unlock()
456		if !ok {
457			return nil, nil, fmt.Errorf("no transaction with id %q", sel.Id)
458		}
459		return tx, func() {}, nil
460	}
461}
462
463func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
464	// Assume this is probably a DML statement. Queries tend to use ExecuteStreamingSql.
465	// TODO: Expand this to support more things.
466
467	obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id)
468	if !ok {
469		return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector)
470	}
471	tid := string(obj.Id)
472	_ = tid // TODO: lookup an existing transaction by ID.
473
474	stmt, err := spansql.ParseDMLStmt(req.Sql)
475	if err != nil {
476		return nil, status.Errorf(codes.InvalidArgument, "bad DML: %v", err)
477	}
478	params, err := parseQueryParams(req.GetParams(), req.ParamTypes)
479	if err != nil {
480		return nil, err
481	}
482
483	s.logf("Executing: %s", stmt.SQL())
484	if len(params) > 0 {
485		s.logf("        ▹ %v", params)
486	}
487
488	n, err := s.db.Execute(stmt, params)
489	if err != nil {
490		return nil, err
491	}
492	return &spannerpb.ResultSet{
493		Stats: &spannerpb.ResultSetStats{
494			RowCount: &spannerpb.ResultSetStats_RowCountExact{int64(n)},
495		},
496	}, nil
497}
498
499func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
500	tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction)
501	if err != nil {
502		return err
503	}
504	defer cleanup()
505
506	q, err := spansql.ParseQuery(req.Sql)
507	if err != nil {
508		// TODO: check what code the real Spanner returns here.
509		return status.Errorf(codes.InvalidArgument, "bad query: %v", err)
510	}
511
512	params, err := parseQueryParams(req.GetParams(), req.ParamTypes)
513	if err != nil {
514		return err
515	}
516
517	s.logf("Querying: %s", q.SQL())
518	if len(params) > 0 {
519		s.logf("        ▹ %v", params)
520	}
521
522	ri, err := s.db.Query(q, params)
523	if err != nil {
524		return err
525	}
526	return s.readStream(stream.Context(), tx, stream.Send, ri)
527}
528
529// TODO: Read
530
531func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
532	tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction)
533	if err != nil {
534		return err
535	}
536	defer cleanup()
537
538	// Bail out if various advanced features are being used.
539	if req.Index != "" {
540		// This is okay; we can still return results.
541		s.logf("Warning: index reads (%q) not supported", req.Index)
542	}
543	if len(req.ResumeToken) > 0 {
544		// This should only happen if we send resume_token ourselves.
545		return fmt.Errorf("read resumption not supported")
546	}
547	if len(req.PartitionToken) > 0 {
548		return fmt.Errorf("partition restrictions not supported")
549	}
550
551	var ri rowIter
552	if req.KeySet.All {
553		s.logf("Reading all from %s (cols: %v)", req.Table, req.Columns)
554		ri, err = s.db.ReadAll(req.Table, req.Columns, req.Limit)
555	} else {
556		s.logf("Reading rows from %d keys and %d ranges from %s (cols: %v)", len(req.KeySet.Keys), len(req.KeySet.Ranges), req.Table, req.Columns)
557		ri, err = s.db.Read(req.Table, req.Columns, req.KeySet.Keys, makeKeyRangeList(req.KeySet.Ranges), req.Limit)
558	}
559	if err != nil {
560		return err
561	}
562
563	// TODO: Figure out the right contexts to use here. There's the session one (sess.ctx),
564	// but also this specific RPC one (stream.Context()). Which takes precedence?
565	// They appear to be independent.
566
567	return s.readStream(stream.Context(), tx, stream.Send, ri)
568}
569
570func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error {
571	// Build the result set metadata.
572	rsm := &spannerpb.ResultSetMetadata{
573		RowType: &spannerpb.StructType{},
574		// TODO: transaction info?
575	}
576	for _, ci := range ri.Cols() {
577		st, err := spannerTypeFromType(ci.Type)
578		if err != nil {
579			return err
580		}
581		rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{
582			Name: ci.Name,
583			Type: st,
584		})
585	}
586
587	for {
588		row, err := ri.Next()
589		if err == io.EOF {
590			break
591		} else if err != nil {
592			return err
593		}
594
595		values := make([]*structpb.Value, len(row))
596		for i, x := range row {
597			v, err := spannerValueFromValue(x)
598			if err != nil {
599				return err
600			}
601			values[i] = v
602		}
603
604		prs := &spannerpb.PartialResultSet{
605			Metadata: rsm,
606			Values:   values,
607		}
608		if err := send(prs); err != nil {
609			return err
610		}
611
612		// ResultSetMetadata is only set for the first PartialResultSet.
613		rsm = nil
614	}
615
616	return nil
617}
618
619func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
620	//s.logf("BeginTransaction(%v)", req)
621
622	s.mu.Lock()
623	sess, ok := s.sessions[req.Session]
624	s.mu.Unlock()
625	if !ok {
626		// TODO: what error does the real Spanner return?
627		return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Session)
628	}
629
630	id := genRandomTransaction()
631	tx := s.db.NewTransaction()
632
633	sess.mu.Lock()
634	sess.lastUse = time.Now()
635	sess.transactions[id] = tx
636	sess.mu.Unlock()
637
638	tr := &spannerpb.Transaction{Id: []byte(id)}
639
640	if req.GetOptions().GetReadOnly().GetReturnReadTimestamp() {
641		// Return the last commit timestamp.
642		// This isn't wholly accurate, but may be good enough for simple use cases.
643		tr.ReadTimestamp = timestampProto(s.db.LastCommitTimestamp())
644	}
645
646	return tr, nil
647}
648
649func (s *server) Commit(ctx context.Context, req *spannerpb.CommitRequest) (resp *spannerpb.CommitResponse, err error) {
650	//s.logf("Commit(%q, %q)", req.Session, req.Transaction)
651
652	obj, ok := req.Transaction.(*spannerpb.CommitRequest_TransactionId)
653	if !ok {
654		return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction)
655	}
656	tid := string(obj.TransactionId)
657
658	tx, err := s.popTx(req.Session, tid)
659	if err != nil {
660		return nil, err
661	}
662	defer func() {
663		if err != nil {
664			tx.Rollback()
665		}
666	}()
667	tx.Start()
668
669	for _, m := range req.Mutations {
670		switch op := m.Operation.(type) {
671		default:
672			return nil, fmt.Errorf("unsupported mutation operation type %T", op)
673		case *spannerpb.Mutation_Insert:
674			ins := op.Insert
675			err := s.db.Insert(tx, ins.Table, ins.Columns, ins.Values)
676			if err != nil {
677				return nil, err
678			}
679		case *spannerpb.Mutation_Update:
680			up := op.Update
681			err := s.db.Update(tx, up.Table, up.Columns, up.Values)
682			if err != nil {
683				return nil, err
684			}
685		case *spannerpb.Mutation_InsertOrUpdate:
686			iou := op.InsertOrUpdate
687			err := s.db.InsertOrUpdate(tx, iou.Table, iou.Columns, iou.Values)
688			if err != nil {
689				return nil, err
690			}
691		case *spannerpb.Mutation_Delete_:
692			del := op.Delete
693			ks := del.KeySet
694
695			err := s.db.Delete(tx, del.Table, ks.Keys, makeKeyRangeList(ks.Ranges), ks.All)
696			if err != nil {
697				return nil, err
698			}
699		}
700
701	}
702
703	ts, err := tx.Commit()
704	if err != nil {
705		return nil, err
706	}
707
708	return &spannerpb.CommitResponse{
709		CommitTimestamp: timestampProto(ts),
710	}, nil
711}
712
713func (s *server) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) {
714	s.logf("Rollback(%v)", req)
715
716	tx, err := s.popTx(req.Session, string(req.TransactionId))
717	if err != nil {
718		return nil, err
719	}
720
721	tx.Rollback()
722
723	return &emptypb.Empty{}, nil
724}
725
726// TODO: PartitionQuery, PartitionRead
727
728func parseQueryParams(p *structpb.Struct, types map[string]*spannerpb.Type) (queryParams, error) {
729	params := make(queryParams)
730	for k, v := range p.GetFields() {
731		switch v := v.Kind.(type) {
732		default:
733			return nil, fmt.Errorf("unsupported well-known type value kind %T", v)
734		case *structpb.Value_NullValue:
735			params[k] = queryParam{Value: nil} // TODO: set a type?
736		case *structpb.Value_NumberValue:
737			params[k] = queryParam{Value: v.NumberValue, Type: float64Type}
738		case *structpb.Value_StringValue:
739			switch types[k].Code {
740			case spannerpb.TypeCode_INT64:
741				params[k] = queryParam{Value: v.StringValue, Type: int64Type}
742			case spannerpb.TypeCode_TIMESTAMP:
743				params[k] = queryParam{Value: v.StringValue, Type: spansql.Type{Base: spansql.Timestamp}}
744			case spannerpb.TypeCode_DATE:
745				params[k] = queryParam{Value: v.StringValue, Type: spansql.Type{Base: spansql.Date}}
746			case spannerpb.TypeCode_BYTES:
747				b, err := base64.StdEncoding.DecodeString(v.StringValue)
748				if err != nil {
749					return nil, err
750				}
751				params[k] = queryParam{Value: b, Type: spansql.Type{Base: spansql.Bytes, Len: spansql.MaxLen}}
752			default:
753				// All other types represented on the wire as a string are stored internally as strings.
754				// We don't often get a type hint unfortunately, so ths type code here may be wrong.
755				params[k] = queryParam{Value: v.StringValue, Type: stringType}
756			}
757		}
758	}
759	return params, nil
760}
761
762func spannerTypeFromType(typ spansql.Type) (*spannerpb.Type, error) {
763	var code spannerpb.TypeCode
764	switch typ.Base {
765	default:
766		return nil, fmt.Errorf("unhandled base type %d", typ.Base)
767	case spansql.Bool:
768		code = spannerpb.TypeCode_BOOL
769	case spansql.Int64:
770		code = spannerpb.TypeCode_INT64
771	case spansql.Float64:
772		code = spannerpb.TypeCode_FLOAT64
773	case spansql.String:
774		code = spannerpb.TypeCode_STRING
775	case spansql.Bytes:
776		code = spannerpb.TypeCode_BYTES
777	case spansql.Date:
778		code = spannerpb.TypeCode_DATE
779	case spansql.Timestamp:
780		code = spannerpb.TypeCode_TIMESTAMP
781	}
782	st := &spannerpb.Type{Code: code}
783	if typ.Array {
784		st = &spannerpb.Type{
785			Code:             spannerpb.TypeCode_ARRAY,
786			ArrayElementType: st,
787		}
788	}
789	return st, nil
790}
791
792func spannerValueFromValue(x interface{}) (*structpb.Value, error) {
793	switch x := x.(type) {
794	default:
795		return nil, fmt.Errorf("unhandled database value type %T", x)
796	case bool:
797		return &structpb.Value{Kind: &structpb.Value_BoolValue{x}}, nil
798	case int64:
799		// The Spanner int64 is actually a decimal string.
800		s := strconv.FormatInt(x, 10)
801		return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil
802	case float64:
803		return &structpb.Value{Kind: &structpb.Value_NumberValue{x}}, nil
804	case string:
805		return &structpb.Value{Kind: &structpb.Value_StringValue{x}}, nil
806	case []byte:
807		return &structpb.Value{Kind: &structpb.Value_StringValue{base64.StdEncoding.EncodeToString(x)}}, nil
808	case nil:
809		return &structpb.Value{Kind: &structpb.Value_NullValue{}}, nil
810	case []interface{}:
811		var vs []*structpb.Value
812		for _, elem := range x {
813			v, err := spannerValueFromValue(elem)
814			if err != nil {
815				return nil, err
816			}
817			vs = append(vs, v)
818		}
819		return &structpb.Value{Kind: &structpb.Value_ListValue{
820			&structpb.ListValue{Values: vs},
821		}}, nil
822	}
823}
824
825func makeKeyRangeList(ranges []*spannerpb.KeyRange) keyRangeList {
826	var krl keyRangeList
827	for _, r := range ranges {
828		krl = append(krl, makeKeyRange(r))
829	}
830	return krl
831}
832
833func makeKeyRange(r *spannerpb.KeyRange) *keyRange {
834	var kr keyRange
835	switch s := r.StartKeyType.(type) {
836	case *spannerpb.KeyRange_StartClosed:
837		kr.start = s.StartClosed
838		kr.startClosed = true
839	case *spannerpb.KeyRange_StartOpen:
840		kr.start = s.StartOpen
841	}
842	switch e := r.EndKeyType.(type) {
843	case *spannerpb.KeyRange_EndClosed:
844		kr.end = e.EndClosed
845		kr.endClosed = true
846	case *spannerpb.KeyRange_EndOpen:
847		kr.end = e.EndOpen
848	}
849	return &kr
850}
851