1/* 2Copyright 2019 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17/* 18Package spannertest contains test helpers for working with Cloud Spanner. 19 20This package is EXPERIMENTAL, and is lacking several features. See the README.md 21file in this directory for more details. 22 23In-memory fake 24 25This package has an in-memory fake implementation of spanner. To use it, 26create a Server, and then connect to it with no security: 27 srv, err := spannertest.NewServer("localhost:0") 28 ... 29 conn, err := grpc.DialContext(ctx, srv.Addr, grpc.WithInsecure()) 30 ... 31 client, err := spanner.NewClient(ctx, db, option.WithGRPCConn(conn)) 32 ... 33 34Alternatively, create a Server, then set the SPANNER_EMULATOR_HOST environment 35variable and use the regular spanner.NewClient: 36 srv, err := spannertest.NewServer("localhost:0") 37 ... 38 os.Setenv("SPANNER_EMULATOR_HOST", srv.Addr) 39 client, err := spanner.NewClient(ctx, db) 40 ... 41 42The same server also supports database admin operations for use with 43the cloud.google.com/go/spanner/admin/database/apiv1 package. This only 44simulates the existence of a single database; its name is ignored. 45*/ 46package spannertest 47 48import ( 49 "context" 50 "encoding/base64" 51 "fmt" 52 "io" 53 "log" 54 "math/rand" 55 "net" 56 "strconv" 57 "sync" 58 "sync/atomic" 59 "time" 60 61 "github.com/golang/protobuf/proto" 62 "github.com/golang/protobuf/ptypes" 63 "google.golang.org/grpc" 64 "google.golang.org/grpc/codes" 65 "google.golang.org/grpc/status" 66 67 anypb "github.com/golang/protobuf/ptypes/any" 68 emptypb "github.com/golang/protobuf/ptypes/empty" 69 structpb "github.com/golang/protobuf/ptypes/struct" 70 timestamppb "github.com/golang/protobuf/ptypes/timestamp" 71 lropb "google.golang.org/genproto/googleapis/longrunning" 72 adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" 73 spannerpb "google.golang.org/genproto/googleapis/spanner/v1" 74 75 "cloud.google.com/go/civil" 76 "cloud.google.com/go/spanner/spansql" 77) 78 79// Server is an in-memory Cloud Spanner fake. 80// It is unauthenticated, non-performant, and only a rough approximation. 81type Server struct { 82 Addr string 83 84 l net.Listener 85 srv *grpc.Server 86 s *server 87} 88 89// server is the real implementation of the fake. 90// It is a separate and unexported type so the API won't be cluttered with 91// methods that are only relevant to the fake's implementation. 92type server struct { 93 logf Logger 94 95 db database 96 start time.Time 97 98 mu sync.Mutex 99 sessions map[string]*session 100 lros map[string]*lro 101 102 // Any unimplemented methods will cause a panic. 103 // TODO: Switch to Unimplemented at some point? spannerpb would need regenerating. 104 adminpb.DatabaseAdminServer 105 spannerpb.SpannerServer 106 lropb.OperationsServer 107} 108 109type session struct { 110 name string 111 creation time.Time 112 113 // This context tracks the lifetime of this session. 114 // It is canceled in DeleteSession. 115 ctx context.Context 116 cancel func() 117 118 mu sync.Mutex 119 lastUse time.Time 120 transactions map[string]*transaction 121} 122 123func (s *session) Proto() *spannerpb.Session { 124 s.mu.Lock() 125 defer s.mu.Unlock() 126 m := &spannerpb.Session{ 127 Name: s.name, 128 CreateTime: timestampProto(s.creation), 129 ApproximateLastUseTime: timestampProto(s.lastUse), 130 } 131 return m 132} 133 134// timestampProto returns a valid timestamp.Timestamp, 135// or nil if the given time is zero or isn't representable. 136func timestampProto(t time.Time) *timestamppb.Timestamp { 137 if t.IsZero() { 138 return nil 139 } 140 ts, err := ptypes.TimestampProto(t) 141 if err != nil { 142 return nil 143 } 144 return ts 145} 146 147// lro represents a Long-Running Operation, generally a schema change. 148type lro struct { 149 mu sync.Mutex 150 state *lropb.Operation 151 152 // waitc is closed when anyone starts waiting on the LRO. 153 // waitatom is CAS'd from 0 to 1 to make that closing safe. 154 waitc chan struct{} 155 waitatom int32 156} 157 158func newLRO(initState *lropb.Operation) *lro { 159 return &lro{ 160 state: initState, 161 waitc: make(chan struct{}), 162 } 163} 164 165func (l *lro) noWait() { 166 if atomic.CompareAndSwapInt32(&l.waitatom, 0, 1) { 167 close(l.waitc) 168 } 169} 170 171func (l *lro) State() *lropb.Operation { 172 l.mu.Lock() 173 defer l.mu.Unlock() 174 return proto.Clone(l.state).(*lropb.Operation) 175} 176 177// Logger is something that can be used for logging. 178// It is matched by log.Printf and testing.T.Logf. 179type Logger func(format string, args ...interface{}) 180 181// NewServer creates a new Server. 182// The Server will be listening for gRPC connections, without TLS, on the provided TCP address. 183// The resolved address is available in the Addr field. 184func NewServer(laddr string) (*Server, error) { 185 l, err := net.Listen("tcp", laddr) 186 if err != nil { 187 return nil, err 188 } 189 190 s := &Server{ 191 Addr: l.Addr().String(), 192 l: l, 193 srv: grpc.NewServer(), 194 s: &server{ 195 logf: func(format string, args ...interface{}) { 196 log.Printf("spannertest.inmem: "+format, args...) 197 }, 198 start: time.Now(), 199 sessions: make(map[string]*session), 200 lros: make(map[string]*lro), 201 }, 202 } 203 adminpb.RegisterDatabaseAdminServer(s.srv, s.s) 204 spannerpb.RegisterSpannerServer(s.srv, s.s) 205 lropb.RegisterOperationsServer(s.srv, s.s) 206 207 go s.srv.Serve(s.l) 208 209 return s, nil 210} 211 212// SetLogger sets a logger for the server. 213// You can use a *testing.T as this argument to collate extra information 214// from the execution of the server. 215func (s *Server) SetLogger(l Logger) { s.s.logf = l } 216 217// Close shuts down the server. 218func (s *Server) Close() { 219 s.srv.Stop() 220 s.l.Close() 221} 222 223func genRandomSession() string { 224 var b [4]byte 225 rand.Read(b[:]) 226 return fmt.Sprintf("%x", b) 227} 228 229func genRandomTransaction() string { 230 var b [6]byte 231 rand.Read(b[:]) 232 return fmt.Sprintf("tx-%x", b) 233} 234 235func genRandomOperation() string { 236 var b [3]byte 237 rand.Read(b[:]) 238 return fmt.Sprintf("op-%x", b) 239} 240 241func (s *server) GetOperation(ctx context.Context, req *lropb.GetOperationRequest) (*lropb.Operation, error) { 242 s.mu.Lock() 243 lro, ok := s.lros[req.Name] 244 s.mu.Unlock() 245 if !ok { 246 return nil, status.Errorf(codes.NotFound, "unknown LRO %q", req.Name) 247 } 248 249 // Someone is waiting on this LRO. Disable sleeping in its Run method. 250 lro.noWait() 251 252 return lro.State(), nil 253} 254 255func (s *server) GetDatabase(ctx context.Context, req *adminpb.GetDatabaseRequest) (*adminpb.Database, error) { 256 s.logf("GetDatabase(%q)", req.Name) 257 258 return &adminpb.Database{ 259 Name: req.Name, 260 State: adminpb.Database_READY, 261 CreateTime: timestampProto(s.start), 262 }, nil 263} 264 265// UpdateDDL applies the given DDL to the server. 266// 267// This is a convenience method for tests that may assume an existing schema. 268// The more general approach is to dial this server using an admin client, and 269// use the UpdateDatabaseDdl RPC method. 270func (s *Server) UpdateDDL(ddl *spansql.DDL) error { 271 ctx := context.Background() 272 for _, stmt := range ddl.List { 273 if st := s.s.runOneDDL(ctx, stmt); st.Code() != codes.OK { 274 return st.Err() 275 } 276 } 277 return nil 278} 279 280func (s *server) UpdateDatabaseDdl(ctx context.Context, req *adminpb.UpdateDatabaseDdlRequest) (*lropb.Operation, error) { 281 // Parse all the DDL statements first. 282 var stmts []spansql.DDLStmt 283 for _, s := range req.Statements { 284 stmt, err := spansql.ParseDDLStmt(s) 285 if err != nil { 286 // TODO: check what code the real Spanner returns here. 287 return nil, status.Errorf(codes.InvalidArgument, "bad DDL statement %q: %v", s, err) 288 } 289 stmts = append(stmts, stmt) 290 } 291 292 // Nothing should be depending on the exact structure of this, 293 // but it is specified in google/spanner/admin/database/v1/spanner_database_admin.proto. 294 id := "projects/fake-proj/instances/fake-instance/databases/fake-db/operations/" + genRandomOperation() 295 lro := newLRO(&lropb.Operation{Name: id}) 296 s.mu.Lock() 297 s.lros[id] = lro 298 s.mu.Unlock() 299 300 go lro.Run(s, stmts) 301 return lro.State(), nil 302} 303 304func (l *lro) Run(s *server, stmts []spansql.DDLStmt) { 305 ctx := context.Background() 306 307 for _, stmt := range stmts { 308 // Simulate delayed DDL application, but only if nobody is waiting. 309 select { 310 case <-time.After(100 * time.Millisecond): 311 case <-l.waitc: 312 } 313 314 if st := s.runOneDDL(ctx, stmt); st.Code() != codes.OK { 315 l.mu.Lock() 316 l.state.Done = true 317 l.state.Result = &lropb.Operation_Error{st.Proto()} 318 l.mu.Unlock() 319 return 320 } 321 } 322 323 l.mu.Lock() 324 l.state.Done = true 325 l.state.Result = &lropb.Operation_Response{&anypb.Any{}} 326 l.mu.Unlock() 327} 328 329func (s *server) runOneDDL(ctx context.Context, stmt spansql.DDLStmt) *status.Status { 330 return s.db.ApplyDDL(stmt) 331} 332 333func (s *server) GetDatabaseDdl(ctx context.Context, req *adminpb.GetDatabaseDdlRequest) (*adminpb.GetDatabaseDdlResponse, error) { 334 s.logf("GetDatabaseDdl(%q)", req.Database) 335 336 var resp adminpb.GetDatabaseDdlResponse 337 for _, stmt := range s.db.GetDDL() { 338 resp.Statements = append(resp.Statements, stmt.SQL()) 339 } 340 return &resp, nil 341} 342 343func (s *server) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { 344 //s.logf("CreateSession(%q)", req.Database) 345 return s.newSession(), nil 346} 347 348func (s *server) newSession() *spannerpb.Session { 349 id := genRandomSession() 350 now := time.Now() 351 sess := &session{ 352 name: id, 353 creation: now, 354 lastUse: now, 355 transactions: make(map[string]*transaction), 356 } 357 sess.ctx, sess.cancel = context.WithCancel(context.Background()) 358 359 s.mu.Lock() 360 s.sessions[id] = sess 361 s.mu.Unlock() 362 363 return sess.Proto() 364} 365 366func (s *server) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { 367 //s.logf("BatchCreateSessions(%q)", req.Database) 368 369 var sessions []*spannerpb.Session 370 for i := int32(0); i < req.GetSessionCount(); i++ { 371 sessions = append(sessions, s.newSession()) 372 } 373 374 return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil 375} 376 377func (s *server) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { 378 s.mu.Lock() 379 sess, ok := s.sessions[req.Name] 380 s.mu.Unlock() 381 382 if !ok { 383 // TODO: what error does the real Spanner return? 384 return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name) 385 } 386 387 return sess.Proto(), nil 388} 389 390// TODO: ListSessions 391 392func (s *server) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { 393 //s.logf("DeleteSession(%q)", req.Name) 394 395 s.mu.Lock() 396 sess, ok := s.sessions[req.Name] 397 delete(s.sessions, req.Name) 398 s.mu.Unlock() 399 400 if !ok { 401 // TODO: what error does the real Spanner return? 402 return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name) 403 } 404 405 // Terminate any operations in this session. 406 sess.cancel() 407 408 return &emptypb.Empty{}, nil 409} 410 411// popTx returns an existing transaction, removing it from the session. 412// This is called when a transaction is finishing (Commit, Rollback). 413func (s *server) popTx(sessionID, tid string) (tx *transaction, err error) { 414 s.mu.Lock() 415 sess, ok := s.sessions[sessionID] 416 s.mu.Unlock() 417 if !ok { 418 // TODO: what error does the real Spanner return? 419 return nil, status.Errorf(codes.NotFound, "unknown session %q", sessionID) 420 } 421 422 sess.mu.Lock() 423 sess.lastUse = time.Now() 424 tx, ok = sess.transactions[tid] 425 if ok { 426 delete(sess.transactions, tid) 427 } 428 sess.mu.Unlock() 429 if !ok { 430 // TODO: what error does the real Spanner return? 431 return nil, status.Errorf(codes.NotFound, "unknown transaction ID %q", tid) 432 } 433 return tx, nil 434} 435 436// readTx returns a transaction for the given session and transaction selector. 437// It is used by read/query operations (ExecuteStreamingSql, StreamingRead). 438func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.TransactionSelector) (tx *transaction, cleanup func(), err error) { 439 s.mu.Lock() 440 sess, ok := s.sessions[session] 441 s.mu.Unlock() 442 if !ok { 443 // TODO: what error does the real Spanner return? 444 return nil, nil, status.Errorf(codes.NotFound, "unknown session %q", session) 445 } 446 447 sess.mu.Lock() 448 sess.lastUse = time.Now() 449 sess.mu.Unlock() 450 451 // Only give a read-only transaction regardless of whether the selector 452 // is requesting a read-write or read-only one, since this is in readTx 453 // and so shouldn't be mutating anyway. 454 singleUse := func() (*transaction, func(), error) { 455 tx := s.db.NewReadOnlyTransaction() 456 return tx, tx.Rollback, nil 457 } 458 459 if tsel.GetSelector() == nil { 460 return singleUse() 461 } 462 463 switch sel := tsel.Selector.(type) { 464 default: 465 return nil, nil, fmt.Errorf("TransactionSelector type %T not supported", sel) 466 case *spannerpb.TransactionSelector_SingleUse: 467 // Ignore options (e.g. timestamps). 468 switch mode := sel.SingleUse.Mode.(type) { 469 case *spannerpb.TransactionOptions_ReadOnly_: 470 return singleUse() 471 case *spannerpb.TransactionOptions_ReadWrite_: 472 return singleUse() 473 default: 474 return nil, nil, fmt.Errorf("single use transaction in mode %T not supported", mode) 475 } 476 case *spannerpb.TransactionSelector_Id: 477 sess.mu.Lock() 478 tx, ok := sess.transactions[string(sel.Id)] 479 sess.mu.Unlock() 480 if !ok { 481 return nil, nil, fmt.Errorf("no transaction with id %q", sel.Id) 482 } 483 return tx, func() {}, nil 484 } 485} 486 487func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { 488 // Assume this is probably a DML statement or a ping from the session pool. 489 // Queries normally use ExecuteStreamingSql. 490 // TODO: Expand this to support more things. 491 492 // If it is a single-use transaction we assume it is a query. 493 if req.Transaction.GetSelector() == nil || req.Transaction.GetSingleUse().GetReadOnly() != nil { 494 ri, err := s.executeQuery(req) 495 if err != nil { 496 return nil, err 497 } 498 return s.resultSet(ri) 499 } 500 501 obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id) 502 if !ok { 503 return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector) 504 } 505 tid := string(obj.Id) 506 _ = tid // TODO: lookup an existing transaction by ID. 507 508 stmt, err := spansql.ParseDMLStmt(req.Sql) 509 if err != nil { 510 return nil, status.Errorf(codes.InvalidArgument, "bad DML: %v", err) 511 } 512 params, err := parseQueryParams(req.GetParams(), req.ParamTypes) 513 if err != nil { 514 return nil, err 515 } 516 517 s.logf("Executing: %s", stmt.SQL()) 518 if len(params) > 0 { 519 s.logf(" ▹ %v", params) 520 } 521 522 n, err := s.db.Execute(stmt, params) 523 if err != nil { 524 return nil, err 525 } 526 return &spannerpb.ResultSet{ 527 Stats: &spannerpb.ResultSetStats{ 528 RowCount: &spannerpb.ResultSetStats_RowCountExact{int64(n)}, 529 }, 530 }, nil 531} 532 533func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { 534 tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction) 535 if err != nil { 536 return err 537 } 538 defer cleanup() 539 540 ri, err := s.executeQuery(req) 541 if err != nil { 542 return err 543 } 544 return s.readStream(stream.Context(), tx, stream.Send, ri) 545} 546 547func (s *server) executeQuery(req *spannerpb.ExecuteSqlRequest) (ri rowIter, err error) { 548 q, err := spansql.ParseQuery(req.Sql) 549 if err != nil { 550 // TODO: check what code the real Spanner returns here. 551 return nil, status.Errorf(codes.InvalidArgument, "bad query: %v", err) 552 } 553 554 params, err := parseQueryParams(req.GetParams(), req.ParamTypes) 555 if err != nil { 556 return nil, err 557 } 558 559 s.logf("Querying: %s", q.SQL()) 560 if len(params) > 0 { 561 s.logf(" ▹ %v", params) 562 } 563 564 return s.db.Query(q, params) 565} 566 567// TODO: Read 568 569func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { 570 tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction) 571 if err != nil { 572 return err 573 } 574 defer cleanup() 575 576 // Bail out if various advanced features are being used. 577 if req.Index != "" { 578 // This is okay; we can still return results. 579 s.logf("Warning: index reads (%q) not supported", req.Index) 580 } 581 if len(req.ResumeToken) > 0 { 582 // This should only happen if we send resume_token ourselves. 583 return fmt.Errorf("read resumption not supported") 584 } 585 if len(req.PartitionToken) > 0 { 586 return fmt.Errorf("partition restrictions not supported") 587 } 588 589 var ri rowIter 590 if req.KeySet.All { 591 s.logf("Reading all from %s (cols: %v)", req.Table, req.Columns) 592 ri, err = s.db.ReadAll(spansql.ID(req.Table), idList(req.Columns), req.Limit) 593 } else { 594 s.logf("Reading rows from %d keys and %d ranges from %s (cols: %v)", len(req.KeySet.Keys), len(req.KeySet.Ranges), req.Table, req.Columns) 595 ri, err = s.db.Read(spansql.ID(req.Table), idList(req.Columns), req.KeySet.Keys, makeKeyRangeList(req.KeySet.Ranges), req.Limit) 596 } 597 if err != nil { 598 return err 599 } 600 601 // TODO: Figure out the right contexts to use here. There's the session one (sess.ctx), 602 // but also this specific RPC one (stream.Context()). Which takes precedence? 603 // They appear to be independent. 604 605 return s.readStream(stream.Context(), tx, stream.Send, ri) 606} 607 608func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) { 609 rsm, err := s.buildResultSetMetadata(ri) 610 if err != nil { 611 return nil, err 612 } 613 rs := &spannerpb.ResultSet{ 614 Metadata: rsm, 615 } 616 for { 617 row, err := ri.Next() 618 if err == io.EOF { 619 break 620 } else if err != nil { 621 return nil, err 622 } 623 624 values := make([]*structpb.Value, len(row)) 625 for i, x := range row { 626 v, err := spannerValueFromValue(x) 627 if err != nil { 628 return nil, err 629 } 630 values[i] = v 631 } 632 rs.Rows = append(rs.Rows, &structpb.ListValue{Values: values}) 633 } 634 return rs, nil 635} 636 637func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error { 638 rsm, err := s.buildResultSetMetadata(ri) 639 if err != nil { 640 return err 641 } 642 643 for { 644 row, err := ri.Next() 645 if err == io.EOF { 646 break 647 } else if err != nil { 648 return err 649 } 650 651 values := make([]*structpb.Value, len(row)) 652 for i, x := range row { 653 v, err := spannerValueFromValue(x) 654 if err != nil { 655 return err 656 } 657 values[i] = v 658 } 659 660 prs := &spannerpb.PartialResultSet{ 661 Metadata: rsm, 662 Values: values, 663 } 664 if err := send(prs); err != nil { 665 return err 666 } 667 668 // ResultSetMetadata is only set for the first PartialResultSet. 669 rsm = nil 670 } 671 672 return nil 673} 674 675func (s *server) buildResultSetMetadata(ri rowIter) (*spannerpb.ResultSetMetadata, error) { 676 // Build the result set metadata. 677 rsm := &spannerpb.ResultSetMetadata{ 678 RowType: &spannerpb.StructType{}, 679 // TODO: transaction info? 680 } 681 for _, ci := range ri.Cols() { 682 st, err := spannerTypeFromType(ci.Type) 683 if err != nil { 684 return nil, err 685 } 686 rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{ 687 Name: string(ci.Name), 688 Type: st, 689 }) 690 } 691 return rsm, nil 692} 693 694func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { 695 //s.logf("BeginTransaction(%v)", req) 696 697 s.mu.Lock() 698 sess, ok := s.sessions[req.Session] 699 s.mu.Unlock() 700 if !ok { 701 // TODO: what error does the real Spanner return? 702 return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Session) 703 } 704 705 id := genRandomTransaction() 706 tx := s.db.NewTransaction() 707 708 sess.mu.Lock() 709 sess.lastUse = time.Now() 710 sess.transactions[id] = tx 711 sess.mu.Unlock() 712 713 tr := &spannerpb.Transaction{Id: []byte(id)} 714 715 if req.GetOptions().GetReadOnly().GetReturnReadTimestamp() { 716 // Return the last commit timestamp. 717 // This isn't wholly accurate, but may be good enough for simple use cases. 718 tr.ReadTimestamp = timestampProto(s.db.LastCommitTimestamp()) 719 } 720 721 return tr, nil 722} 723 724func (s *server) Commit(ctx context.Context, req *spannerpb.CommitRequest) (resp *spannerpb.CommitResponse, err error) { 725 //s.logf("Commit(%q, %q)", req.Session, req.Transaction) 726 727 obj, ok := req.Transaction.(*spannerpb.CommitRequest_TransactionId) 728 if !ok { 729 return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction) 730 } 731 tid := string(obj.TransactionId) 732 733 tx, err := s.popTx(req.Session, tid) 734 if err != nil { 735 return nil, err 736 } 737 defer func() { 738 if err != nil { 739 tx.Rollback() 740 } 741 }() 742 tx.Start() 743 744 for _, m := range req.Mutations { 745 switch op := m.Operation.(type) { 746 default: 747 return nil, fmt.Errorf("unsupported mutation operation type %T", op) 748 case *spannerpb.Mutation_Insert: 749 ins := op.Insert 750 err := s.db.Insert(tx, spansql.ID(ins.Table), idList(ins.Columns), ins.Values) 751 if err != nil { 752 return nil, err 753 } 754 case *spannerpb.Mutation_Update: 755 up := op.Update 756 err := s.db.Update(tx, spansql.ID(up.Table), idList(up.Columns), up.Values) 757 if err != nil { 758 return nil, err 759 } 760 case *spannerpb.Mutation_InsertOrUpdate: 761 iou := op.InsertOrUpdate 762 err := s.db.InsertOrUpdate(tx, spansql.ID(iou.Table), idList(iou.Columns), iou.Values) 763 if err != nil { 764 return nil, err 765 } 766 case *spannerpb.Mutation_Delete_: 767 del := op.Delete 768 ks := del.KeySet 769 770 err := s.db.Delete(tx, spansql.ID(del.Table), ks.Keys, makeKeyRangeList(ks.Ranges), ks.All) 771 if err != nil { 772 return nil, err 773 } 774 } 775 776 } 777 778 ts, err := tx.Commit() 779 if err != nil { 780 return nil, err 781 } 782 783 return &spannerpb.CommitResponse{ 784 CommitTimestamp: timestampProto(ts), 785 }, nil 786} 787 788func (s *server) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { 789 s.logf("Rollback(%v)", req) 790 791 tx, err := s.popTx(req.Session, string(req.TransactionId)) 792 if err != nil { 793 return nil, err 794 } 795 796 tx.Rollback() 797 798 return &emptypb.Empty{}, nil 799} 800 801// TODO: PartitionQuery, PartitionRead 802 803func parseQueryParams(p *structpb.Struct, types map[string]*spannerpb.Type) (queryParams, error) { 804 params := make(queryParams) 805 for k, v := range p.GetFields() { 806 p, err := parseQueryParam(v, types[k]) 807 if err != nil { 808 return nil, err 809 } 810 params[k] = p 811 } 812 return params, nil 813} 814 815func parseQueryParam(v *structpb.Value, typ *spannerpb.Type) (queryParam, error) { 816 // TODO: Use valForType and typeFromSpannerType more comprehensively here? 817 // They are only used for StringValue vs, since that's what mostly needs parsing. 818 819 rawv := v 820 switch v := v.Kind.(type) { 821 default: 822 return queryParam{}, fmt.Errorf("unsupported well-known type value kind %T", v) 823 case *structpb.Value_NullValue: 824 return queryParam{Value: nil}, nil // TODO: set a type? 825 case *structpb.Value_BoolValue: 826 return queryParam{Value: v.BoolValue, Type: boolType}, nil 827 case *structpb.Value_NumberValue: 828 return queryParam{Value: v.NumberValue, Type: float64Type}, nil 829 case *structpb.Value_StringValue: 830 t, err := typeFromSpannerType(typ) 831 if err != nil { 832 return queryParam{}, err 833 } 834 val, err := valForType(rawv, t) 835 if err != nil { 836 return queryParam{}, err 837 } 838 return queryParam{Value: val, Type: t}, nil 839 case *structpb.Value_ListValue: 840 var list []interface{} 841 for _, elem := range v.ListValue.Values { 842 // TODO: Change the type parameter passed through? We only look at the code. 843 p, err := parseQueryParam(elem, typ) 844 if err != nil { 845 return queryParam{}, err 846 } 847 list = append(list, p.Value) 848 } 849 t, err := typeFromSpannerType(typ) 850 if err != nil { 851 return queryParam{}, err 852 } 853 return queryParam{Value: list, Type: t}, nil 854 } 855} 856 857func typeFromSpannerType(st *spannerpb.Type) (spansql.Type, error) { 858 switch st.Code { 859 default: 860 return spansql.Type{}, fmt.Errorf("unhandled spanner type code %v", st.Code) 861 case spannerpb.TypeCode_BOOL: 862 return spansql.Type{Base: spansql.Bool}, nil 863 case spannerpb.TypeCode_INT64: 864 return spansql.Type{Base: spansql.Int64}, nil 865 case spannerpb.TypeCode_FLOAT64: 866 return spansql.Type{Base: spansql.Float64}, nil 867 case spannerpb.TypeCode_TIMESTAMP: 868 return spansql.Type{Base: spansql.Timestamp}, nil 869 case spannerpb.TypeCode_DATE: 870 return spansql.Type{Base: spansql.Date}, nil 871 case spannerpb.TypeCode_STRING: 872 return spansql.Type{Base: spansql.String}, nil // no len 873 case spannerpb.TypeCode_BYTES: 874 return spansql.Type{Base: spansql.Bytes}, nil // no len 875 case spannerpb.TypeCode_ARRAY: 876 typ, err := typeFromSpannerType(st.ArrayElementType) 877 if err != nil { 878 return spansql.Type{}, err 879 } 880 typ.Array = true 881 return typ, nil 882 } 883} 884 885func spannerTypeFromType(typ spansql.Type) (*spannerpb.Type, error) { 886 var code spannerpb.TypeCode 887 switch typ.Base { 888 default: 889 return nil, fmt.Errorf("unhandled base type %d", typ.Base) 890 case spansql.Bool: 891 code = spannerpb.TypeCode_BOOL 892 case spansql.Int64: 893 code = spannerpb.TypeCode_INT64 894 case spansql.Float64: 895 code = spannerpb.TypeCode_FLOAT64 896 case spansql.String: 897 code = spannerpb.TypeCode_STRING 898 case spansql.Bytes: 899 code = spannerpb.TypeCode_BYTES 900 case spansql.Date: 901 code = spannerpb.TypeCode_DATE 902 case spansql.Timestamp: 903 code = spannerpb.TypeCode_TIMESTAMP 904 } 905 st := &spannerpb.Type{Code: code} 906 if typ.Array { 907 st = &spannerpb.Type{ 908 Code: spannerpb.TypeCode_ARRAY, 909 ArrayElementType: st, 910 } 911 } 912 return st, nil 913} 914 915func spannerValueFromValue(x interface{}) (*structpb.Value, error) { 916 switch x := x.(type) { 917 default: 918 return nil, fmt.Errorf("unhandled database value type %T", x) 919 case bool: 920 return &structpb.Value{Kind: &structpb.Value_BoolValue{x}}, nil 921 case int64: 922 // The Spanner int64 is actually a decimal string. 923 s := strconv.FormatInt(x, 10) 924 return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil 925 case float64: 926 return &structpb.Value{Kind: &structpb.Value_NumberValue{x}}, nil 927 case string: 928 return &structpb.Value{Kind: &structpb.Value_StringValue{x}}, nil 929 case []byte: 930 return &structpb.Value{Kind: &structpb.Value_StringValue{base64.StdEncoding.EncodeToString(x)}}, nil 931 case civil.Date: 932 // RFC 3339 date format. 933 return &structpb.Value{Kind: &structpb.Value_StringValue{x.String()}}, nil 934 case time.Time: 935 // RFC 3339 timestamp format with zone Z. 936 s := x.Format("2006-01-02T15:04:05.999999999Z") 937 return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil 938 case nil: 939 return &structpb.Value{Kind: &structpb.Value_NullValue{}}, nil 940 case []interface{}: 941 var vs []*structpb.Value 942 for _, elem := range x { 943 v, err := spannerValueFromValue(elem) 944 if err != nil { 945 return nil, err 946 } 947 vs = append(vs, v) 948 } 949 return &structpb.Value{Kind: &structpb.Value_ListValue{ 950 &structpb.ListValue{Values: vs}, 951 }}, nil 952 } 953} 954 955func makeKeyRangeList(ranges []*spannerpb.KeyRange) keyRangeList { 956 var krl keyRangeList 957 for _, r := range ranges { 958 krl = append(krl, makeKeyRange(r)) 959 } 960 return krl 961} 962 963func makeKeyRange(r *spannerpb.KeyRange) *keyRange { 964 var kr keyRange 965 switch s := r.StartKeyType.(type) { 966 case *spannerpb.KeyRange_StartClosed: 967 kr.start = s.StartClosed 968 kr.startClosed = true 969 case *spannerpb.KeyRange_StartOpen: 970 kr.start = s.StartOpen 971 } 972 switch e := r.EndKeyType.(type) { 973 case *spannerpb.KeyRange_EndClosed: 974 kr.end = e.EndClosed 975 kr.endClosed = true 976 case *spannerpb.KeyRange_EndOpen: 977 kr.end = e.EndOpen 978 } 979 return &kr 980} 981 982func idList(ss []string) (ids []spansql.ID) { 983 for _, s := range ss { 984 ids = append(ids, spansql.ID(s)) 985 } 986 return 987} 988