1/* 2Copyright 2019 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17/* 18Package spannertest contains test helpers for working with Cloud Spanner. 19 20This package is EXPERIMENTAL, and is lacking many features. See the README.md 21file in this directory for more details. 22 23In-memory fake 24 25This package has an in-memory fake implementation of spanner. To use it, 26create a Server, and then connect to it with no security: 27 srv, err := spannertest.NewServer("localhost:0") 28 ... 29 conn, err := grpc.DialContext(ctx, srv.Addr, grpc.WithInsecure()) 30 ... 31 client, err := spanner.NewClient(ctx, db, option.WithGRPCConn(conn)) 32 ... 33 34Alternatively, create a Server, then set the SPANNER_EMULATOR_HOST environment 35variable and use the regular spanner.NewClient: 36 srv, err := spannertest.NewServer("localhost:0") 37 ... 38 os.Setenv("SPANNER_EMULATOR_HOST", srv.Addr) 39 client, err := spanner.NewClient(ctx, db) 40 ... 41 42The same server also supports database admin operations for use with 43the cloud.google.com/go/spanner/admin/database/apiv1 package. This only 44simulates the existence of a single database; its name is ignored. 45*/ 46package spannertest 47 48import ( 49 "context" 50 "encoding/base64" 51 "fmt" 52 "io" 53 "log" 54 "math/rand" 55 "net" 56 "strconv" 57 "sync" 58 "time" 59 60 "github.com/golang/protobuf/proto" 61 "github.com/golang/protobuf/ptypes" 62 "google.golang.org/grpc" 63 "google.golang.org/grpc/codes" 64 "google.golang.org/grpc/status" 65 66 anypb "github.com/golang/protobuf/ptypes/any" 67 emptypb "github.com/golang/protobuf/ptypes/empty" 68 structpb "github.com/golang/protobuf/ptypes/struct" 69 timestamppb "github.com/golang/protobuf/ptypes/timestamp" 70 lropb "google.golang.org/genproto/googleapis/longrunning" 71 adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 72 spannerpb "google.golang.org/genproto/googleapis/spanner/v1" 73 74 "cloud.google.com/go/civil" 75 "cloud.google.com/go/spanner/spansql" 76) 77 78// Server is an in-memory Cloud Spanner fake. 79// It is unauthenticated, non-performant, and only a rough approximation. 80type Server struct { 81 Addr string 82 83 l net.Listener 84 srv *grpc.Server 85 s *server 86} 87 88// server is the real implementation of the fake. 89// It is a separate and unexported type so the API won't be cluttered with 90// methods that are only relevant to the fake's implementation. 91type server struct { 92 logf Logger 93 94 db database 95 start time.Time 96 97 mu sync.Mutex 98 sessions map[string]*session 99 lros map[string]*lro 100 101 // Any unimplemented methods will cause a panic. 102 // TODO: Switch to Unimplemented at some point? spannerpb would need regenerating. 103 adminpb.DatabaseAdminServer 104 spannerpb.SpannerServer 105 lropb.OperationsServer 106} 107 108type session struct { 109 name string 110 creation time.Time 111 112 // This context tracks the lifetime of this session. 113 // It is canceled in DeleteSession. 114 ctx context.Context 115 cancel func() 116 117 mu sync.Mutex 118 lastUse time.Time 119 transactions map[string]*transaction 120} 121 122func (s *session) Proto() *spannerpb.Session { 123 s.mu.Lock() 124 defer s.mu.Unlock() 125 m := &spannerpb.Session{ 126 Name: s.name, 127 CreateTime: timestampProto(s.creation), 128 ApproximateLastUseTime: timestampProto(s.lastUse), 129 } 130 return m 131} 132 133// timestampProto returns a valid timestamp.Timestamp, 134// or nil if the given time is zero or isn't representable. 135func timestampProto(t time.Time) *timestamppb.Timestamp { 136 if t.IsZero() { 137 return nil 138 } 139 ts, err := ptypes.TimestampProto(t) 140 if err != nil { 141 return nil 142 } 143 return ts 144} 145 146// lro represents a Long-Running Operation, generally a schema change. 147type lro struct { 148 mu sync.Mutex 149 state *lropb.Operation 150 waiters bool // Whether anyone appears to be waiting. 151} 152 153func (l *lro) State() *lropb.Operation { 154 l.mu.Lock() 155 defer l.mu.Unlock() 156 return proto.Clone(l.state).(*lropb.Operation) 157} 158 159// Logger is something that can be used for logging. 160// It is matched by log.Printf and testing.T.Logf. 161type Logger func(format string, args ...interface{}) 162 163// NewServer creates a new Server. 164// The Server will be listening for gRPC connections, without TLS, on the provided TCP address. 165// The resolved address is available in the Addr field. 166func NewServer(laddr string) (*Server, error) { 167 l, err := net.Listen("tcp", laddr) 168 if err != nil { 169 return nil, err 170 } 171 172 s := &Server{ 173 Addr: l.Addr().String(), 174 l: l, 175 srv: grpc.NewServer(), 176 s: &server{ 177 logf: func(format string, args ...interface{}) { 178 log.Printf("spannertest.inmem: "+format, args...) 179 }, 180 start: time.Now(), 181 sessions: make(map[string]*session), 182 lros: make(map[string]*lro), 183 }, 184 } 185 adminpb.RegisterDatabaseAdminServer(s.srv, s.s) 186 spannerpb.RegisterSpannerServer(s.srv, s.s) 187 lropb.RegisterOperationsServer(s.srv, s.s) 188 189 go s.srv.Serve(s.l) 190 191 return s, nil 192} 193 194// SetLogger sets a logger for the server. 195// You can use a *testing.T as this argument to collate extra information 196// from the execution of the server. 197func (s *Server) SetLogger(l Logger) { s.s.logf = l } 198 199// Close shuts down the server. 200func (s *Server) Close() { 201 s.srv.Stop() 202 s.l.Close() 203} 204 205func genRandomSession() string { 206 var b [4]byte 207 rand.Read(b[:]) 208 return fmt.Sprintf("%x", b) 209} 210 211func genRandomTransaction() string { 212 var b [6]byte 213 rand.Read(b[:]) 214 return fmt.Sprintf("tx-%x", b) 215} 216 217func genRandomOperation() string { 218 var b [3]byte 219 rand.Read(b[:]) 220 return fmt.Sprintf("op-%x", b) 221} 222 223func (s *server) GetOperation(ctx context.Context, req *lropb.GetOperationRequest) (*lropb.Operation, error) { 224 s.mu.Lock() 225 lro, ok := s.lros[req.Name] 226 s.mu.Unlock() 227 if !ok { 228 return nil, status.Errorf(codes.NotFound, "unknown LRO %q", req.Name) 229 } 230 231 // Someone is waiting on this LRO. Disable sleeping in its Run method. 232 lro.mu.Lock() 233 lro.waiters = true 234 lro.mu.Unlock() 235 236 return lro.State(), nil 237} 238 239func (s *server) GetDatabase(ctx context.Context, req *adminpb.GetDatabaseRequest) (*adminpb.Database, error) { 240 s.logf("GetDatabase(%q)", req.Name) 241 242 return &adminpb.Database{ 243 Name: req.Name, 244 State: adminpb.Database_READY, 245 CreateTime: timestampProto(s.start), 246 }, nil 247} 248 249// UpdateDDL applies the given DDL to the server. 250// 251// This is a convenience method for tests that may assume an existing schema. 252// The more general approach is to dial this server using an admin client, and 253// use the UpdateDatabaseDdl RPC method. 254func (s *Server) UpdateDDL(ddl *spansql.DDL) error { 255 ctx := context.Background() 256 for _, stmt := range ddl.List { 257 if st := s.s.runOneDDL(ctx, stmt); st.Code() != codes.OK { 258 return st.Err() 259 } 260 } 261 return nil 262} 263 264func (s *server) UpdateDatabaseDdl(ctx context.Context, req *adminpb.UpdateDatabaseDdlRequest) (*lropb.Operation, error) { 265 // Parse all the DDL statements first. 266 var stmts []spansql.DDLStmt 267 for _, s := range req.Statements { 268 stmt, err := spansql.ParseDDLStmt(s) 269 if err != nil { 270 // TODO: check what code the real Spanner returns here. 271 return nil, status.Errorf(codes.InvalidArgument, "bad DDL statement %q: %v", s, err) 272 } 273 stmts = append(stmts, stmt) 274 } 275 276 // Nothing should be depending on the exact structure of this, 277 // but it is specified in google/spanner/admin/database/v1/spanner_database_admin.proto. 278 id := "projects/fake-proj/instances/fake-instance/databases/fake-db/operations/" + genRandomOperation() 279 lro := &lro{ 280 state: &lropb.Operation{ 281 Name: id, 282 }, 283 } 284 s.mu.Lock() 285 s.lros[id] = lro 286 s.mu.Unlock() 287 288 go lro.Run(s, stmts) 289 return lro.State(), nil 290} 291 292func (l *lro) Run(s *server, stmts []spansql.DDLStmt) { 293 ctx := context.Background() 294 295 for _, stmt := range stmts { 296 l.mu.Lock() 297 waiters := l.waiters 298 l.mu.Unlock() 299 if !waiters { 300 // Simulate delayed DDL application, but only if nobody is waiting. 301 time.Sleep(100 * time.Millisecond) 302 } 303 304 if st := s.runOneDDL(ctx, stmt); st.Code() != codes.OK { 305 l.mu.Lock() 306 l.state.Done = true 307 l.state.Result = &lropb.Operation_Error{st.Proto()} 308 l.mu.Unlock() 309 return 310 } 311 } 312 313 l.mu.Lock() 314 l.state.Done = true 315 l.state.Result = &lropb.Operation_Response{&anypb.Any{}} 316 l.mu.Unlock() 317} 318 319func (s *server) runOneDDL(ctx context.Context, stmt spansql.DDLStmt) *status.Status { 320 return s.db.ApplyDDL(stmt) 321} 322 323func (s *server) GetDatabaseDdl(ctx context.Context, req *adminpb.GetDatabaseDdlRequest) (*adminpb.GetDatabaseDdlResponse, error) { 324 s.logf("GetDatabaseDdl(%q)", req.Database) 325 326 var resp adminpb.GetDatabaseDdlResponse 327 for _, stmt := range s.db.GetDDL() { 328 resp.Statements = append(resp.Statements, stmt.SQL()) 329 } 330 return &resp, nil 331} 332 333func (s *server) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { 334 //s.logf("CreateSession(%q)", req.Database) 335 return s.newSession(), nil 336} 337 338func (s *server) newSession() *spannerpb.Session { 339 id := genRandomSession() 340 now := time.Now() 341 sess := &session{ 342 name: id, 343 creation: now, 344 lastUse: now, 345 transactions: make(map[string]*transaction), 346 } 347 sess.ctx, sess.cancel = context.WithCancel(context.Background()) 348 349 s.mu.Lock() 350 s.sessions[id] = sess 351 s.mu.Unlock() 352 353 return sess.Proto() 354} 355 356func (s *server) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { 357 //s.logf("BatchCreateSessions(%q)", req.Database) 358 359 var sessions []*spannerpb.Session 360 for i := int32(0); i < req.GetSessionCount(); i++ { 361 sessions = append(sessions, s.newSession()) 362 } 363 364 return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil 365} 366 367func (s *server) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { 368 s.mu.Lock() 369 sess, ok := s.sessions[req.Name] 370 s.mu.Unlock() 371 372 if !ok { 373 // TODO: what error does the real Spanner return? 374 return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name) 375 } 376 377 return sess.Proto(), nil 378} 379 380// TODO: ListSessions 381 382func (s *server) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { 383 //s.logf("DeleteSession(%q)", req.Name) 384 385 s.mu.Lock() 386 sess, ok := s.sessions[req.Name] 387 delete(s.sessions, req.Name) 388 s.mu.Unlock() 389 390 if !ok { 391 // TODO: what error does the real Spanner return? 392 return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name) 393 } 394 395 // Terminate any operations in this session. 396 sess.cancel() 397 398 return &emptypb.Empty{}, nil 399} 400 401// popTx returns an existing transaction, removing it from the session. 402// This is called when a transaction is finishing (Commit, Rollback). 403func (s *server) popTx(sessionID, tid string) (tx *transaction, err error) { 404 s.mu.Lock() 405 sess, ok := s.sessions[sessionID] 406 s.mu.Unlock() 407 if !ok { 408 // TODO: what error does the real Spanner return? 409 return nil, status.Errorf(codes.NotFound, "unknown session %q", sessionID) 410 } 411 412 sess.mu.Lock() 413 sess.lastUse = time.Now() 414 tx, ok = sess.transactions[tid] 415 if ok { 416 delete(sess.transactions, tid) 417 } 418 sess.mu.Unlock() 419 if !ok { 420 // TODO: what error does the real Spanner return? 421 return nil, status.Errorf(codes.NotFound, "unknown transaction ID %q", tid) 422 } 423 return tx, nil 424} 425 426// readTx returns a transaction for the given session and transaction selector. 427// It is used by read/query operations (ExecuteStreamingSql, StreamingRead). 428func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.TransactionSelector) (tx *transaction, cleanup func(), err error) { 429 s.mu.Lock() 430 sess, ok := s.sessions[session] 431 s.mu.Unlock() 432 if !ok { 433 // TODO: what error does the real Spanner return? 434 return nil, nil, status.Errorf(codes.NotFound, "unknown session %q", session) 435 } 436 437 sess.mu.Lock() 438 sess.lastUse = time.Now() 439 sess.mu.Unlock() 440 441 // Only give a read-only transaction regardless of whether the selector 442 // is requesting a read-write or read-only one, since this is in readTx 443 // and so shouldn't be mutating anyway. 444 singleUse := func() (*transaction, func(), error) { 445 tx := s.db.NewReadOnlyTransaction() 446 return tx, tx.Rollback, nil 447 } 448 449 if tsel.GetSelector() == nil { 450 return singleUse() 451 } 452 453 switch sel := tsel.Selector.(type) { 454 default: 455 return nil, nil, fmt.Errorf("TransactionSelector type %T not supported", sel) 456 case *spannerpb.TransactionSelector_SingleUse: 457 // Ignore options (e.g. timestamps). 458 switch mode := sel.SingleUse.Mode.(type) { 459 case *spannerpb.TransactionOptions_ReadOnly_: 460 return singleUse() 461 case *spannerpb.TransactionOptions_ReadWrite_: 462 return singleUse() 463 default: 464 return nil, nil, fmt.Errorf("single use transaction in mode %T not supported", mode) 465 } 466 case *spannerpb.TransactionSelector_Id: 467 sess.mu.Lock() 468 tx, ok := sess.transactions[string(sel.Id)] 469 sess.mu.Unlock() 470 if !ok { 471 return nil, nil, fmt.Errorf("no transaction with id %q", sel.Id) 472 } 473 return tx, func() {}, nil 474 } 475} 476 477func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { 478 // Assume this is probably a DML statement. Queries tend to use ExecuteStreamingSql. 479 // TODO: Expand this to support more things. 480 481 obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id) 482 if !ok { 483 return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector) 484 } 485 tid := string(obj.Id) 486 _ = tid // TODO: lookup an existing transaction by ID. 487 488 stmt, err := spansql.ParseDMLStmt(req.Sql) 489 if err != nil { 490 return nil, status.Errorf(codes.InvalidArgument, "bad DML: %v", err) 491 } 492 params, err := parseQueryParams(req.GetParams(), req.ParamTypes) 493 if err != nil { 494 return nil, err 495 } 496 497 s.logf("Executing: %s", stmt.SQL()) 498 if len(params) > 0 { 499 s.logf(" ▹ %v", params) 500 } 501 502 n, err := s.db.Execute(stmt, params) 503 if err != nil { 504 return nil, err 505 } 506 return &spannerpb.ResultSet{ 507 Stats: &spannerpb.ResultSetStats{ 508 RowCount: &spannerpb.ResultSetStats_RowCountExact{int64(n)}, 509 }, 510 }, nil 511} 512 513func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { 514 tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction) 515 if err != nil { 516 return err 517 } 518 defer cleanup() 519 520 q, err := spansql.ParseQuery(req.Sql) 521 if err != nil { 522 // TODO: check what code the real Spanner returns here. 523 return status.Errorf(codes.InvalidArgument, "bad query: %v", err) 524 } 525 526 params, err := parseQueryParams(req.GetParams(), req.ParamTypes) 527 if err != nil { 528 return err 529 } 530 531 s.logf("Querying: %s", q.SQL()) 532 if len(params) > 0 { 533 s.logf(" ▹ %v", params) 534 } 535 536 ri, err := s.db.Query(q, params) 537 if err != nil { 538 return err 539 } 540 return s.readStream(stream.Context(), tx, stream.Send, ri) 541} 542 543// TODO: Read 544 545func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { 546 tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction) 547 if err != nil { 548 return err 549 } 550 defer cleanup() 551 552 // Bail out if various advanced features are being used. 553 if req.Index != "" { 554 // This is okay; we can still return results. 555 s.logf("Warning: index reads (%q) not supported", req.Index) 556 } 557 if len(req.ResumeToken) > 0 { 558 // This should only happen if we send resume_token ourselves. 559 return fmt.Errorf("read resumption not supported") 560 } 561 if len(req.PartitionToken) > 0 { 562 return fmt.Errorf("partition restrictions not supported") 563 } 564 565 var ri rowIter 566 if req.KeySet.All { 567 s.logf("Reading all from %s (cols: %v)", req.Table, req.Columns) 568 ri, err = s.db.ReadAll(spansql.ID(req.Table), idList(req.Columns), req.Limit) 569 } else { 570 s.logf("Reading rows from %d keys and %d ranges from %s (cols: %v)", len(req.KeySet.Keys), len(req.KeySet.Ranges), req.Table, req.Columns) 571 ri, err = s.db.Read(spansql.ID(req.Table), idList(req.Columns), req.KeySet.Keys, makeKeyRangeList(req.KeySet.Ranges), req.Limit) 572 } 573 if err != nil { 574 return err 575 } 576 577 // TODO: Figure out the right contexts to use here. There's the session one (sess.ctx), 578 // but also this specific RPC one (stream.Context()). Which takes precedence? 579 // They appear to be independent. 580 581 return s.readStream(stream.Context(), tx, stream.Send, ri) 582} 583 584func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error { 585 // Build the result set metadata. 586 rsm := &spannerpb.ResultSetMetadata{ 587 RowType: &spannerpb.StructType{}, 588 // TODO: transaction info? 589 } 590 for _, ci := range ri.Cols() { 591 st, err := spannerTypeFromType(ci.Type) 592 if err != nil { 593 return err 594 } 595 rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{ 596 Name: string(ci.Name), 597 Type: st, 598 }) 599 } 600 601 for { 602 row, err := ri.Next() 603 if err == io.EOF { 604 break 605 } else if err != nil { 606 return err 607 } 608 609 values := make([]*structpb.Value, len(row)) 610 for i, x := range row { 611 v, err := spannerValueFromValue(x) 612 if err != nil { 613 return err 614 } 615 values[i] = v 616 } 617 618 prs := &spannerpb.PartialResultSet{ 619 Metadata: rsm, 620 Values: values, 621 } 622 if err := send(prs); err != nil { 623 return err 624 } 625 626 // ResultSetMetadata is only set for the first PartialResultSet. 627 rsm = nil 628 } 629 630 return nil 631} 632 633func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { 634 //s.logf("BeginTransaction(%v)", req) 635 636 s.mu.Lock() 637 sess, ok := s.sessions[req.Session] 638 s.mu.Unlock() 639 if !ok { 640 // TODO: what error does the real Spanner return? 641 return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Session) 642 } 643 644 id := genRandomTransaction() 645 tx := s.db.NewTransaction() 646 647 sess.mu.Lock() 648 sess.lastUse = time.Now() 649 sess.transactions[id] = tx 650 sess.mu.Unlock() 651 652 tr := &spannerpb.Transaction{Id: []byte(id)} 653 654 if req.GetOptions().GetReadOnly().GetReturnReadTimestamp() { 655 // Return the last commit timestamp. 656 // This isn't wholly accurate, but may be good enough for simple use cases. 657 tr.ReadTimestamp = timestampProto(s.db.LastCommitTimestamp()) 658 } 659 660 return tr, nil 661} 662 663func (s *server) Commit(ctx context.Context, req *spannerpb.CommitRequest) (resp *spannerpb.CommitResponse, err error) { 664 //s.logf("Commit(%q, %q)", req.Session, req.Transaction) 665 666 obj, ok := req.Transaction.(*spannerpb.CommitRequest_TransactionId) 667 if !ok { 668 return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction) 669 } 670 tid := string(obj.TransactionId) 671 672 tx, err := s.popTx(req.Session, tid) 673 if err != nil { 674 return nil, err 675 } 676 defer func() { 677 if err != nil { 678 tx.Rollback() 679 } 680 }() 681 tx.Start() 682 683 for _, m := range req.Mutations { 684 switch op := m.Operation.(type) { 685 default: 686 return nil, fmt.Errorf("unsupported mutation operation type %T", op) 687 case *spannerpb.Mutation_Insert: 688 ins := op.Insert 689 err := s.db.Insert(tx, spansql.ID(ins.Table), idList(ins.Columns), ins.Values) 690 if err != nil { 691 return nil, err 692 } 693 case *spannerpb.Mutation_Update: 694 up := op.Update 695 err := s.db.Update(tx, spansql.ID(up.Table), idList(up.Columns), up.Values) 696 if err != nil { 697 return nil, err 698 } 699 case *spannerpb.Mutation_InsertOrUpdate: 700 iou := op.InsertOrUpdate 701 err := s.db.InsertOrUpdate(tx, spansql.ID(iou.Table), idList(iou.Columns), iou.Values) 702 if err != nil { 703 return nil, err 704 } 705 case *spannerpb.Mutation_Delete_: 706 del := op.Delete 707 ks := del.KeySet 708 709 err := s.db.Delete(tx, spansql.ID(del.Table), ks.Keys, makeKeyRangeList(ks.Ranges), ks.All) 710 if err != nil { 711 return nil, err 712 } 713 } 714 715 } 716 717 ts, err := tx.Commit() 718 if err != nil { 719 return nil, err 720 } 721 722 return &spannerpb.CommitResponse{ 723 CommitTimestamp: timestampProto(ts), 724 }, nil 725} 726 727func (s *server) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { 728 s.logf("Rollback(%v)", req) 729 730 tx, err := s.popTx(req.Session, string(req.TransactionId)) 731 if err != nil { 732 return nil, err 733 } 734 735 tx.Rollback() 736 737 return &emptypb.Empty{}, nil 738} 739 740// TODO: PartitionQuery, PartitionRead 741 742func parseQueryParams(p *structpb.Struct, types map[string]*spannerpb.Type) (queryParams, error) { 743 params := make(queryParams) 744 for k, v := range p.GetFields() { 745 p, err := parseQueryParam(v, types[k]) 746 if err != nil { 747 return nil, err 748 } 749 params[k] = p 750 } 751 return params, nil 752} 753 754func parseQueryParam(v *structpb.Value, typ *spannerpb.Type) (queryParam, error) { 755 // TODO: Use valForType and typeFromSpannerType more comprehensively here? 756 // They are only used for StringValue vs, since that's what mostly needs parsing. 757 758 rawv := v 759 switch v := v.Kind.(type) { 760 default: 761 return queryParam{}, fmt.Errorf("unsupported well-known type value kind %T", v) 762 case *structpb.Value_NullValue: 763 return queryParam{Value: nil}, nil // TODO: set a type? 764 case *structpb.Value_BoolValue: 765 return queryParam{Value: v.BoolValue, Type: boolType}, nil 766 case *structpb.Value_NumberValue: 767 return queryParam{Value: v.NumberValue, Type: float64Type}, nil 768 case *structpb.Value_StringValue: 769 t, err := typeFromSpannerType(typ) 770 if err != nil { 771 return queryParam{}, err 772 } 773 val, err := valForType(rawv, t) 774 if err != nil { 775 return queryParam{}, err 776 } 777 return queryParam{Value: val, Type: t}, nil 778 case *structpb.Value_ListValue: 779 var list []interface{} 780 for _, elem := range v.ListValue.Values { 781 // TODO: Change the type parameter passed through? We only look at the code. 782 p, err := parseQueryParam(elem, typ) 783 if err != nil { 784 return queryParam{}, err 785 } 786 list = append(list, p.Value) 787 } 788 t, err := typeFromSpannerType(typ) 789 if err != nil { 790 return queryParam{}, err 791 } 792 return queryParam{Value: list, Type: t}, nil 793 } 794} 795 796func typeFromSpannerType(st *spannerpb.Type) (spansql.Type, error) { 797 switch st.Code { 798 default: 799 return spansql.Type{}, fmt.Errorf("unhandled spanner type code %v", st.Code) 800 case spannerpb.TypeCode_BOOL: 801 return spansql.Type{Base: spansql.Bool}, nil 802 case spannerpb.TypeCode_INT64: 803 return spansql.Type{Base: spansql.Int64}, nil 804 case spannerpb.TypeCode_FLOAT64: 805 return spansql.Type{Base: spansql.Float64}, nil 806 case spannerpb.TypeCode_TIMESTAMP: 807 return spansql.Type{Base: spansql.Timestamp}, nil 808 case spannerpb.TypeCode_DATE: 809 return spansql.Type{Base: spansql.Date}, nil 810 case spannerpb.TypeCode_STRING: 811 return spansql.Type{Base: spansql.String}, nil // no len 812 case spannerpb.TypeCode_BYTES: 813 return spansql.Type{Base: spansql.Bytes}, nil // no len 814 case spannerpb.TypeCode_ARRAY: 815 typ, err := typeFromSpannerType(st.ArrayElementType) 816 if err != nil { 817 return spansql.Type{}, err 818 } 819 typ.Array = true 820 return typ, nil 821 } 822} 823 824func spannerTypeFromType(typ spansql.Type) (*spannerpb.Type, error) { 825 var code spannerpb.TypeCode 826 switch typ.Base { 827 default: 828 return nil, fmt.Errorf("unhandled base type %d", typ.Base) 829 case spansql.Bool: 830 code = spannerpb.TypeCode_BOOL 831 case spansql.Int64: 832 code = spannerpb.TypeCode_INT64 833 case spansql.Float64: 834 code = spannerpb.TypeCode_FLOAT64 835 case spansql.String: 836 code = spannerpb.TypeCode_STRING 837 case spansql.Bytes: 838 code = spannerpb.TypeCode_BYTES 839 case spansql.Date: 840 code = spannerpb.TypeCode_DATE 841 case spansql.Timestamp: 842 code = spannerpb.TypeCode_TIMESTAMP 843 } 844 st := &spannerpb.Type{Code: code} 845 if typ.Array { 846 st = &spannerpb.Type{ 847 Code: spannerpb.TypeCode_ARRAY, 848 ArrayElementType: st, 849 } 850 } 851 return st, nil 852} 853 854func spannerValueFromValue(x interface{}) (*structpb.Value, error) { 855 switch x := x.(type) { 856 default: 857 return nil, fmt.Errorf("unhandled database value type %T", x) 858 case bool: 859 return &structpb.Value{Kind: &structpb.Value_BoolValue{x}}, nil 860 case int64: 861 // The Spanner int64 is actually a decimal string. 862 s := strconv.FormatInt(x, 10) 863 return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil 864 case float64: 865 return &structpb.Value{Kind: &structpb.Value_NumberValue{x}}, nil 866 case string: 867 return &structpb.Value{Kind: &structpb.Value_StringValue{x}}, nil 868 case []byte: 869 return &structpb.Value{Kind: &structpb.Value_StringValue{base64.StdEncoding.EncodeToString(x)}}, nil 870 case civil.Date: 871 // RFC 3339 date format. 872 return &structpb.Value{Kind: &structpb.Value_StringValue{x.String()}}, nil 873 case time.Time: 874 // RFC 3339 timestamp format with zone Z. 875 s := x.Format("2006-01-02T15:04:05.999999999Z") 876 return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil 877 case nil: 878 return &structpb.Value{Kind: &structpb.Value_NullValue{}}, nil 879 case []interface{}: 880 var vs []*structpb.Value 881 for _, elem := range x { 882 v, err := spannerValueFromValue(elem) 883 if err != nil { 884 return nil, err 885 } 886 vs = append(vs, v) 887 } 888 return &structpb.Value{Kind: &structpb.Value_ListValue{ 889 &structpb.ListValue{Values: vs}, 890 }}, nil 891 } 892} 893 894func makeKeyRangeList(ranges []*spannerpb.KeyRange) keyRangeList { 895 var krl keyRangeList 896 for _, r := range ranges { 897 krl = append(krl, makeKeyRange(r)) 898 } 899 return krl 900} 901 902func makeKeyRange(r *spannerpb.KeyRange) *keyRange { 903 var kr keyRange 904 switch s := r.StartKeyType.(type) { 905 case *spannerpb.KeyRange_StartClosed: 906 kr.start = s.StartClosed 907 kr.startClosed = true 908 case *spannerpb.KeyRange_StartOpen: 909 kr.start = s.StartOpen 910 } 911 switch e := r.EndKeyType.(type) { 912 case *spannerpb.KeyRange_EndClosed: 913 kr.end = e.EndClosed 914 kr.endClosed = true 915 case *spannerpb.KeyRange_EndOpen: 916 kr.end = e.EndOpen 917 } 918 return &kr 919} 920 921func idList(ss []string) (ids []spansql.ID) { 922 for _, s := range ss { 923 ids = append(ids, spansql.ID(s)) 924 } 925 return 926} 927