/*
Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package spannertest

// This file contains the implementation of the Spanner fake itself,
// namely the part behind the RPC interface.

// TODO: missing transactionality in a serious way!

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	structpb "github.com/golang/protobuf/ptypes/struct"

	"cloud.google.com/go/civil"
	"cloud.google.com/go/spanner/spansql"
)

// database is the in-memory state of a single fake Spanner database:
// its tables (schema and rows) and the names of its indexes.
type database struct {
	// mu is locked by the methods in this file around reads and writes of
	// lastTS, tables and indexes. It is not held while a table's own lock
	// is used for row access, except briefly by GetDDL and ApplyDDL.
	mu      sync.Mutex
	lastTS  time.Time // last commit timestamp
	tables  map[spansql.ID]*table
	indexes map[spansql.ID]struct{} // only record their existence

	// rwMu serializes read-write transactions: Start locks it and
	// Commit/Rollback release it.
	rwMu sync.Mutex // held by read-write transactions
}

// table holds the schema and the rows of a single table.
type table struct {
	// mu is locked around reads and mutations of the schema and rows below.
	mu sync.Mutex

	// Information about the table columns.
	// They are reordered on table creation so the primary key columns come first.
	cols      []colInfo
	colIndex  map[spansql.ID]int // col name to index
	origIndex map[spansql.ID]int // original index of each column upon construction (renumbered on add/drop column)
	pkCols    int                // number of primary key columns (may be 0)
	pkDesc    []bool             // whether each primary key column is in descending order

	// Rows are stored in primary key order.
	// Each row holds one element per entry in cols, in the same order.
	rows []row
}

// colInfo represents information about a column in a table or result set.
68type colInfo struct { 69 Name spansql.ID 70 Type spansql.Type 71 NotNull bool // only set for table columns 72 AggIndex int // Index+1 of SELECT list for which this is an aggregate value. 73 Alias spansql.PathExp // an alternate name for this column (result sets only) 74} 75 76// commitTimestampSentinel is a sentinel value for TIMESTAMP fields with allow_commit_timestamp=true. 77// It is accepted, but never stored. 78var commitTimestampSentinel = &struct{}{} 79 80// transaction records information about a running transaction. 81// This is not safe for concurrent use. 82type transaction struct { 83 // readOnly is whether this transaction was constructed 84 // for read-only use, and should yield errors if used 85 // to perform a mutation. 86 readOnly bool 87 88 d *database 89 commitTimestamp time.Time // not set if readOnly 90 unlock func() // may be nil 91} 92 93func (d *database) NewReadOnlyTransaction() *transaction { 94 return &transaction{ 95 readOnly: true, 96 } 97} 98 99func (d *database) NewTransaction() *transaction { 100 return &transaction{ 101 d: d, 102 } 103} 104 105// Start starts the transaction and commits to a specific commit timestamp. 106// This also locks out any other read-write transaction on this database 107// until Commit/Rollback are called. 108func (tx *transaction) Start() { 109 // Commit timestamps are only guaranteed to be unique 110 // when transactions write to overlapping sets of fields. 111 // This simulated database exceeds that guarantee. 112 113 // Grab rwMu for the duration of this transaction. 114 // Take it before d.mu so we don't hold that lock 115 // while waiting for d.rwMu, which is held for longer. 
116 tx.d.rwMu.Lock() 117 118 tx.d.mu.Lock() 119 const tsRes = 1 * time.Microsecond 120 now := time.Now().UTC().Truncate(tsRes) 121 if !now.After(tx.d.lastTS) { 122 now = tx.d.lastTS.Add(tsRes) 123 } 124 tx.d.lastTS = now 125 tx.d.mu.Unlock() 126 127 tx.commitTimestamp = now 128 tx.unlock = tx.d.rwMu.Unlock 129} 130 131func (tx *transaction) checkMutable() error { 132 if tx.readOnly { 133 // TODO: is this the right status? 134 return status.Errorf(codes.InvalidArgument, "transaction is read-only") 135 } 136 return nil 137} 138 139func (tx *transaction) Commit() (time.Time, error) { 140 if tx.unlock != nil { 141 tx.unlock() 142 } 143 return tx.commitTimestamp, nil 144} 145 146func (tx *transaction) Rollback() { 147 if tx.unlock != nil { 148 tx.unlock() 149 } 150 // TODO: actually rollback 151} 152 153/* 154row represents a list of data elements. 155 156The mapping between Spanner types and Go types internal to this package are: 157 BOOL bool 158 INT64 int64 159 FLOAT64 float64 160 STRING string 161 BYTES []byte 162 DATE civil.Date 163 TIMESTAMP time.Time (location set to UTC) 164 ARRAY<T> []interface{} 165 STRUCT TODO 166*/ 167type row []interface{} 168 169func (r row) copyDataElem(index int) interface{} { 170 v := r[index] 171 if is, ok := v.([]interface{}); ok { 172 // Deep-copy array values. 173 v = append([]interface{}(nil), is...) 174 } 175 return v 176} 177 178// copyData returns a copy of the row. 179func (r row) copyAllData() row { 180 dst := make(row, 0, len(r)) 181 for i := range r { 182 dst = append(dst, r.copyDataElem(i)) 183 } 184 return dst 185} 186 187// copyData returns a copy of a subset of a row. 
188func (r row) copyData(indexes []int) row { 189 if len(indexes) == 0 { 190 return nil 191 } 192 dst := make(row, 0, len(indexes)) 193 for _, i := range indexes { 194 dst = append(dst, r.copyDataElem(i)) 195 } 196 return dst 197} 198 199func (d *database) LastCommitTimestamp() time.Time { 200 d.mu.Lock() 201 defer d.mu.Unlock() 202 return d.lastTS 203} 204 205func (d *database) GetDDL() []spansql.DDLStmt { 206 // This lacks fidelity, but captures the details we support. 207 d.mu.Lock() 208 defer d.mu.Unlock() 209 210 var stmts []spansql.DDLStmt 211 212 for name, t := range d.tables { 213 ct := &spansql.CreateTable{ 214 Name: name, 215 } 216 217 t.mu.Lock() 218 for i, col := range t.cols { 219 ct.Columns = append(ct.Columns, spansql.ColumnDef{ 220 Name: col.Name, 221 Type: col.Type, 222 NotNull: col.NotNull, 223 // TODO: AllowCommitTimestamp 224 }) 225 if i < t.pkCols { 226 ct.PrimaryKey = append(ct.PrimaryKey, spansql.KeyPart{ 227 Column: col.Name, 228 Desc: t.pkDesc[i], 229 }) 230 } 231 } 232 t.mu.Unlock() 233 234 stmts = append(stmts, ct) 235 } 236 237 return stmts 238} 239 240func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status { 241 d.mu.Lock() 242 defer d.mu.Unlock() 243 244 // Lazy init. 245 if d.tables == nil { 246 d.tables = make(map[spansql.ID]*table) 247 } 248 if d.indexes == nil { 249 d.indexes = make(map[spansql.ID]struct{}) 250 } 251 252 switch stmt := stmt.(type) { 253 default: 254 return status.Newf(codes.Unimplemented, "unhandled DDL statement type %T", stmt) 255 case *spansql.CreateTable: 256 if _, ok := d.tables[stmt.Name]; ok { 257 return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name) 258 } 259 if len(stmt.PrimaryKey) == 0 { 260 return status.Newf(codes.InvalidArgument, "table %s has no primary key", stmt.Name) 261 } 262 263 // TODO: check stmt.Interleave details. 264 265 // Record original column ordering. 
266 orig := make(map[spansql.ID]int) 267 for i, col := range stmt.Columns { 268 orig[col.Name] = i 269 } 270 271 // Move primary keys first, preserving their order. 272 pk := make(map[spansql.ID]int) 273 var pkDesc []bool 274 for i, kp := range stmt.PrimaryKey { 275 pk[kp.Column] = -1000 + i 276 pkDesc = append(pkDesc, kp.Desc) 277 } 278 sort.SliceStable(stmt.Columns, func(i, j int) bool { 279 a, b := pk[stmt.Columns[i].Name], pk[stmt.Columns[j].Name] 280 return a < b 281 }) 282 283 t := &table{ 284 colIndex: make(map[spansql.ID]int), 285 origIndex: orig, 286 pkCols: len(pk), 287 pkDesc: pkDesc, 288 } 289 for _, cd := range stmt.Columns { 290 if st := t.addColumn(cd, true); st.Code() != codes.OK { 291 return st 292 } 293 } 294 for col := range pk { 295 if _, ok := t.colIndex[col]; !ok { 296 return status.Newf(codes.InvalidArgument, "primary key column %q not in table", col) 297 } 298 } 299 d.tables[stmt.Name] = t 300 return nil 301 case *spansql.CreateIndex: 302 if _, ok := d.indexes[stmt.Name]; ok { 303 return status.Newf(codes.AlreadyExists, "index %s already exists", stmt.Name) 304 } 305 d.indexes[stmt.Name] = struct{}{} 306 return nil 307 case *spansql.DropTable: 308 if _, ok := d.tables[stmt.Name]; !ok { 309 return status.Newf(codes.NotFound, "no table named %s", stmt.Name) 310 } 311 // TODO: check for indexes on this table. 
312 delete(d.tables, stmt.Name) 313 return nil 314 case *spansql.DropIndex: 315 if _, ok := d.indexes[stmt.Name]; !ok { 316 return status.Newf(codes.NotFound, "no index named %s", stmt.Name) 317 } 318 delete(d.indexes, stmt.Name) 319 return nil 320 case *spansql.AlterTable: 321 t, ok := d.tables[stmt.Name] 322 if !ok { 323 return status.Newf(codes.NotFound, "no table named %s", stmt.Name) 324 } 325 switch alt := stmt.Alteration.(type) { 326 default: 327 return status.Newf(codes.Unimplemented, "unhandled DDL table alteration type %T", alt) 328 case spansql.AddColumn: 329 if st := t.addColumn(alt.Def, false); st.Code() != codes.OK { 330 return st 331 } 332 return nil 333 case spansql.DropColumn: 334 if st := t.dropColumn(alt.Name); st.Code() != codes.OK { 335 return st 336 } 337 return nil 338 case spansql.AlterColumn: 339 if st := t.alterColumn(alt); st.Code() != codes.OK { 340 return st 341 } 342 return nil 343 } 344 } 345 346} 347 348func (d *database) table(tbl spansql.ID) (*table, error) { 349 d.mu.Lock() 350 defer d.mu.Unlock() 351 352 t, ok := d.tables[tbl] 353 if !ok { 354 return nil, status.Errorf(codes.NotFound, "no table named %s", tbl) 355 } 356 return t, nil 357} 358 359// writeValues executes a write option (Insert, Update, etc.). 
360func (d *database) writeValues(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error { 361 if err := tx.checkMutable(); err != nil { 362 return err 363 } 364 365 t, err := d.table(tbl) 366 if err != nil { 367 return err 368 } 369 370 t.mu.Lock() 371 defer t.mu.Unlock() 372 373 colIndexes, err := t.colIndexes(cols) 374 if err != nil { 375 return err 376 } 377 revIndex := make(map[int]int) // table index to col index 378 for j, i := range colIndexes { 379 revIndex[i] = j 380 } 381 382 for pki := 0; pki < t.pkCols; pki++ { 383 _, ok := revIndex[pki] 384 if !ok { 385 return status.Errorf(codes.InvalidArgument, "primary key column %s not included in write", t.cols[pki].Name) 386 } 387 } 388 389 for _, vs := range values { 390 if len(vs.Values) != len(colIndexes) { 391 return status.Errorf(codes.InvalidArgument, "row of %d values can't be written to %d columns", len(vs.Values), len(colIndexes)) 392 } 393 394 r := make(row, len(t.cols)) 395 for j, v := range vs.Values { 396 i := colIndexes[j] 397 398 x, err := valForType(v, t.cols[i].Type) 399 if err != nil { 400 return err 401 } 402 if x == commitTimestampSentinel { 403 x = tx.commitTimestamp 404 } 405 if x == nil && t.cols[i].NotNull { 406 return status.Errorf(codes.FailedPrecondition, "%s must not be NULL in table %s", t.cols[i].Name, tbl) 407 } 408 409 r[i] = x 410 } 411 // TODO: enforce that provided timestamp for commit_timestamp=true columns 412 // are not ahead of the transaction's commit timestamp. 
413 414 if err := f(t, colIndexes, r); err != nil { 415 return err 416 } 417 } 418 419 return nil 420} 421 422func (d *database) Insert(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { 423 return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { 424 pk := r[:t.pkCols] 425 rowNum, found := t.rowForPK(pk) 426 if found { 427 return status.Errorf(codes.AlreadyExists, "row already in table") 428 } 429 t.insertRow(rowNum, r) 430 return nil 431 }) 432} 433 434func (d *database) Update(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { 435 return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { 436 if t.pkCols == 0 { 437 return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl) 438 } 439 pk := r[:t.pkCols] 440 rowNum, found := t.rowForPK(pk) 441 if !found { 442 // TODO: is this the right way to return `NOT_FOUND`? 443 return status.Errorf(codes.NotFound, "row not in table") 444 } 445 446 for _, i := range colIndexes { 447 t.rows[rowNum][i] = r[i] 448 } 449 return nil 450 }) 451} 452 453func (d *database) InsertOrUpdate(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { 454 return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { 455 pk := r[:t.pkCols] 456 rowNum, found := t.rowForPK(pk) 457 if !found { 458 // New row; do an insert. 459 t.insertRow(rowNum, r) 460 } else { 461 // Existing row; do an update. 
462 for _, i := range colIndexes { 463 t.rows[rowNum][i] = r[i] 464 } 465 } 466 return nil 467 }) 468} 469 470// TODO: Replace 471 472func (d *database) Delete(tx *transaction, table spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error { 473 if err := tx.checkMutable(); err != nil { 474 return err 475 } 476 477 t, err := d.table(table) 478 if err != nil { 479 return err 480 } 481 482 t.mu.Lock() 483 defer t.mu.Unlock() 484 485 if all { 486 t.rows = nil 487 return nil 488 } 489 490 for _, key := range keys { 491 pk, err := t.primaryKey(key.Values) 492 if err != nil { 493 return err 494 } 495 // Not an error if the key does not exist. 496 rowNum, found := t.rowForPK(pk) 497 if found { 498 copy(t.rows[rowNum:], t.rows[rowNum+1:]) 499 t.rows = t.rows[:len(t.rows)-1] 500 } 501 } 502 503 for _, r := range keyRanges { 504 r.startKey, err = t.primaryKeyPrefix(r.start.Values) 505 if err != nil { 506 return err 507 } 508 r.endKey, err = t.primaryKeyPrefix(r.end.Values) 509 if err != nil { 510 return err 511 } 512 startRow, endRow := t.findRange(r) 513 if n := endRow - startRow; n > 0 { 514 copy(t.rows[startRow:], t.rows[endRow:]) 515 t.rows = t.rows[:len(t.rows)-n] 516 } 517 } 518 519 return nil 520} 521 522// readTable executes a read option (Read, ReadAll). 
523func (d *database) readTable(table spansql.ID, cols []spansql.ID, f func(*table, *rawIter, []int) error) (*rawIter, error) { 524 t, err := d.table(table) 525 if err != nil { 526 return nil, err 527 } 528 529 t.mu.Lock() 530 defer t.mu.Unlock() 531 532 colIndexes, err := t.colIndexes(cols) 533 if err != nil { 534 return nil, err 535 } 536 537 ri := &rawIter{} 538 for _, i := range colIndexes { 539 ri.cols = append(ri.cols, t.cols[i]) 540 } 541 return ri, f(t, ri, colIndexes) 542} 543 544func (d *database) Read(tbl spansql.ID, cols []spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, limit int64) (rowIter, error) { 545 // The real Cloud Spanner returns an error if the key set is empty by definition. 546 // That doesn't seem to be well-defined, but it is a common error to attempt a read with no keys, 547 // so catch that here and return a representative error. 548 if len(keys) == 0 && len(keyRanges) == 0 { 549 return nil, status.Error(codes.Unimplemented, "Cloud Spanner does not support reading no keys") 550 } 551 552 return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error { 553 // "If the same key is specified multiple times in the set (for 554 // example if two ranges, two keys, or a key and a range 555 // overlap), Cloud Spanner behaves as if the key were only 556 // specified once." 557 done := make(map[int]bool) // row numbers we've included in ri. 558 559 // Specific keys. 560 for _, key := range keys { 561 pk, err := t.primaryKey(key.Values) 562 if err != nil { 563 return err 564 } 565 // Not an error if the key does not exist. 566 rowNum, found := t.rowForPK(pk) 567 if !found { 568 continue 569 } 570 if done[rowNum] { 571 continue 572 } 573 done[rowNum] = true 574 ri.add(t.rows[rowNum], colIndexes) 575 if limit > 0 && len(ri.rows) >= int(limit) { 576 return nil 577 } 578 } 579 580 // Key ranges. 
581 for _, r := range keyRanges { 582 var err error 583 r.startKey, err = t.primaryKeyPrefix(r.start.Values) 584 if err != nil { 585 return err 586 } 587 r.endKey, err = t.primaryKeyPrefix(r.end.Values) 588 if err != nil { 589 return err 590 } 591 startRow, endRow := t.findRange(r) 592 for rowNum := startRow; rowNum < endRow; rowNum++ { 593 if done[rowNum] { 594 continue 595 } 596 done[rowNum] = true 597 ri.add(t.rows[rowNum], colIndexes) 598 if limit > 0 && len(ri.rows) >= int(limit) { 599 return nil 600 } 601 } 602 } 603 604 return nil 605 }) 606} 607 608func (d *database) ReadAll(tbl spansql.ID, cols []spansql.ID, limit int64) (*rawIter, error) { 609 return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error { 610 for _, r := range t.rows { 611 ri.add(r, colIndexes) 612 if limit > 0 && len(ri.rows) >= int(limit) { 613 break 614 } 615 } 616 return nil 617 }) 618} 619 620func (t *table) addColumn(cd spansql.ColumnDef, newTable bool) *status.Status { 621 if !newTable && cd.NotNull { 622 return status.Newf(codes.InvalidArgument, "new non-key columns cannot be NOT NULL") 623 } 624 625 if _, ok := t.colIndex[cd.Name]; ok { 626 return status.Newf(codes.AlreadyExists, "column %s already exists", cd.Name) 627 } 628 629 t.mu.Lock() 630 defer t.mu.Unlock() 631 632 if len(t.rows) > 0 { 633 if cd.NotNull { 634 // TODO: what happens in this case? 635 return status.Newf(codes.Unimplemented, "can't add NOT NULL columns to non-empty tables yet") 636 } 637 for i := range t.rows { 638 t.rows[i] = append(t.rows[i], nil) 639 } 640 } 641 642 t.cols = append(t.cols, colInfo{ 643 Name: cd.Name, 644 Type: cd.Type, 645 NotNull: cd.NotNull, 646 }) 647 t.colIndex[cd.Name] = len(t.cols) - 1 648 if !newTable { 649 t.origIndex[cd.Name] = len(t.cols) - 1 650 } 651 652 return nil 653} 654 655func (t *table) dropColumn(name spansql.ID) *status.Status { 656 // Only permit dropping non-key columns that aren't part of a secondary index. 
657 // We don't support indexes, so only check that it isn't part of the primary key. 658 659 t.mu.Lock() 660 defer t.mu.Unlock() 661 662 ci, ok := t.colIndex[name] 663 if !ok { 664 // TODO: What's the right response code? 665 return status.Newf(codes.InvalidArgument, "unknown column %q", name) 666 } 667 if ci < t.pkCols { 668 // TODO: What's the right response code? 669 return status.Newf(codes.InvalidArgument, "can't drop primary key column %q", name) 670 } 671 672 // Remove from cols and colIndex, and renumber colIndex and origIndex. 673 t.cols = append(t.cols[:ci], t.cols[ci+1:]...) 674 delete(t.colIndex, name) 675 for i, col := range t.cols { 676 t.colIndex[col.Name] = i 677 } 678 pre := t.origIndex[name] 679 delete(t.origIndex, name) 680 for n, i := range t.origIndex { 681 if i > pre { 682 t.origIndex[n]-- 683 } 684 } 685 686 // Drop data. 687 for i := range t.rows { 688 t.rows[i] = append(t.rows[i][:ci], t.rows[i][ci+1:]...) 689 } 690 691 return nil 692} 693 694func (t *table) alterColumn(alt spansql.AlterColumn) *status.Status { 695 // Supported changes here are: 696 // Add NOT NULL to a non-key column, excluding ARRAY columns. 697 // Remove NOT NULL from a non-key column. 698 // Change a STRING column to a BYTES column or a BYTES column to a STRING column. 699 // Increase or decrease the length limit for a STRING or BYTES type (including to MAX). 700 // Enable or disable commit timestamps in value and primary key columns. 701 // https://cloud.google.com/spanner/docs/schema-updates#supported-updates 702 703 // TODO: codes.InvalidArgument is used throughout here for reporting errors, 704 // but that has not been validated against the real Spanner. 
705 706 sct, ok := alt.Alteration.(spansql.SetColumnType) 707 if !ok { 708 return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL()) 709 } 710 711 t.mu.Lock() 712 defer t.mu.Unlock() 713 714 ci, ok := t.colIndex[alt.Name] 715 if !ok { 716 return status.Newf(codes.InvalidArgument, "unknown column %q", alt.Name) 717 } 718 719 oldT, newT := t.cols[ci].Type, sct.Type 720 stringOrBytes := func(bt spansql.TypeBase) bool { return bt == spansql.String || bt == spansql.Bytes } 721 722 // First phase: Check the validity of the change. 723 // TODO: Don't permit changes to allow commit timestamps. 724 if !t.cols[ci].NotNull && sct.NotNull { 725 // Adding NOT NULL is not permitted for primary key columns or array typed columns. 726 if ci < t.pkCols { 727 return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on primary key column %q", alt.Name) 728 } 729 if oldT.Array { 730 return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on array-typed column %q", alt.Name) 731 } 732 // Validate that there are no NULL values. 733 for _, row := range t.rows { 734 if row[ci] == nil { 735 return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on column %q that contains NULL values", alt.Name) 736 } 737 } 738 } 739 var conv func(x interface{}) interface{} 740 if stringOrBytes(oldT.Base) && stringOrBytes(newT.Base) && !oldT.Array && !newT.Array { 741 // Change between STRING and BYTES is fine, as is increasing/decreasing the length limit. 742 // TODO: This should permit array conversions too. 743 // TODO: Validate data; length limit changes should be rejected if they'd lead to data loss, for instance. 
744 if oldT.Base == spansql.Bytes && newT.Base == spansql.String { 745 conv = func(x interface{}) interface{} { return string(x.([]byte)) } 746 } else if oldT.Base == spansql.String && newT.Base == spansql.Bytes { 747 conv = func(x interface{}) interface{} { return []byte(x.(string)) } 748 } 749 } else if oldT == newT { 750 // Same type; only NOT NULL changes. 751 } else { // TODO: Support other alterations. 752 return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL()) 753 } 754 755 // Second phase: Make type transformations. 756 t.cols[ci].NotNull = sct.NotNull 757 t.cols[ci].Type = newT 758 if conv != nil { 759 for _, row := range t.rows { 760 if row[ci] != nil { // NULL stays as NULL. 761 row[ci] = conv(row[ci]) 762 } 763 } 764 } 765 return nil 766} 767 768func (t *table) insertRow(rowNum int, r row) { 769 t.rows = append(t.rows, nil) 770 copy(t.rows[rowNum+1:], t.rows[rowNum:]) 771 t.rows[rowNum] = r 772} 773 774// findRange finds the rows included in the key range, 775// reporting it as a half-open interval. 776// r.startKey and r.endKey should be populated. 777func (t *table) findRange(r *keyRange) (int, int) { 778 // startRow is the first row matching the range. 779 startRow := sort.Search(len(t.rows), func(i int) bool { 780 return rowCmp(r.startKey, t.rows[i][:t.pkCols], t.pkDesc) <= 0 781 }) 782 if startRow == len(t.rows) { 783 return startRow, startRow 784 } 785 if !r.startClosed && rowCmp(r.startKey, t.rows[startRow][:t.pkCols], t.pkDesc) == 0 { 786 startRow++ 787 } 788 789 // endRow is one more than the last row matching the range. 790 endRow := sort.Search(len(t.rows), func(i int) bool { 791 return rowCmp(r.endKey, t.rows[i][:t.pkCols], t.pkDesc) < 0 792 }) 793 if !r.endClosed && rowCmp(r.endKey, t.rows[endRow-1][:t.pkCols], t.pkDesc) == 0 { 794 endRow-- 795 } 796 797 return startRow, endRow 798} 799 800// colIndexes returns the indexes for the named columns. 
801func (t *table) colIndexes(cols []spansql.ID) ([]int, error) { 802 var is []int 803 for _, col := range cols { 804 i, ok := t.colIndex[col] 805 if !ok { 806 return nil, status.Errorf(codes.InvalidArgument, "column %s not in table", col) 807 } 808 is = append(is, i) 809 } 810 return is, nil 811} 812 813// primaryKey constructs the internal representation of a primary key. 814// The list of given values must be in 1:1 correspondence with the primary key of the table. 815func (t *table) primaryKey(values []*structpb.Value) ([]interface{}, error) { 816 if len(values) != t.pkCols { 817 return nil, status.Errorf(codes.InvalidArgument, "primary key length mismatch: got %d values, table has %d", len(values), t.pkCols) 818 } 819 return t.primaryKeyPrefix(values) 820} 821 822// primaryKeyPrefix constructs the internal representation of a primary key prefix. 823func (t *table) primaryKeyPrefix(values []*structpb.Value) ([]interface{}, error) { 824 if len(values) > t.pkCols { 825 return nil, status.Errorf(codes.InvalidArgument, "primary key length too long: got %d values, table has %d", len(values), t.pkCols) 826 } 827 828 var pk []interface{} 829 for i, value := range values { 830 v, err := valForType(value, t.cols[i].Type) 831 if err != nil { 832 return nil, err 833 } 834 pk = append(pk, v) 835 } 836 return pk, nil 837} 838 839// rowForPK returns the index of t.rows that holds the row for the given primary key, and true. 840// If the given primary key isn't found, it returns the row that should hold it, and false. 
841func (t *table) rowForPK(pk []interface{}) (row int, found bool) { 842 if len(pk) != t.pkCols { 843 panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols)) 844 } 845 846 i := sort.Search(len(t.rows), func(i int) bool { 847 return rowCmp(pk, t.rows[i][:t.pkCols], t.pkDesc) <= 0 848 }) 849 if i == len(t.rows) { 850 return i, false 851 } 852 return i, rowEqual(pk, t.rows[i][:t.pkCols]) 853} 854 855// rowCmp compares two rows, returning -1/0/+1. 856// The desc arg indicates whether each column is in a descending order. 857// This is used for primary key matching and so doesn't support array/struct types. 858// a is permitted to be shorter than b. 859func rowCmp(a, b []interface{}, desc []bool) int { 860 for i := 0; i < len(a); i++ { 861 if cmp := compareVals(a[i], b[i]); cmp != 0 { 862 if desc[i] { 863 cmp = -cmp 864 } 865 return cmp 866 } 867 } 868 return 0 869} 870 871// rowEqual reports whether two rows are equal. 872// This doesn't support array/struct types. 873func rowEqual(a, b []interface{}) bool { 874 for i := 0; i < len(a); i++ { 875 if compareVals(a[i], b[i]) != 0 { 876 return false 877 } 878 } 879 return true 880} 881 882// valForType converts a value from its RPC form into its internal representation. 883func valForType(v *structpb.Value, t spansql.Type) (interface{}, error) { 884 if _, ok := v.Kind.(*structpb.Value_NullValue); ok { 885 return nil, nil 886 } 887 888 if lv, ok := v.Kind.(*structpb.Value_ListValue); ok && t.Array { 889 et := t // element type 890 et.Array = false 891 892 // Construct the non-nil slice for the list. 
893 arr := make([]interface{}, 0, len(lv.ListValue.Values)) 894 for _, v := range lv.ListValue.Values { 895 x, err := valForType(v, et) 896 if err != nil { 897 return nil, err 898 } 899 arr = append(arr, x) 900 } 901 return arr, nil 902 } 903 904 switch t.Base { 905 case spansql.Bool: 906 bv, ok := v.Kind.(*structpb.Value_BoolValue) 907 if ok { 908 return bv.BoolValue, nil 909 } 910 case spansql.Int64: 911 // The Spanner protocol encodes int64 as a decimal string. 912 sv, ok := v.Kind.(*structpb.Value_StringValue) 913 if ok { 914 x, err := strconv.ParseInt(sv.StringValue, 10, 64) 915 if err != nil { 916 return nil, fmt.Errorf("bad int64 string %q: %v", sv.StringValue, err) 917 } 918 return x, nil 919 } 920 case spansql.Float64: 921 nv, ok := v.Kind.(*structpb.Value_NumberValue) 922 if ok { 923 return nv.NumberValue, nil 924 } 925 case spansql.String: 926 sv, ok := v.Kind.(*structpb.Value_StringValue) 927 if ok { 928 return sv.StringValue, nil 929 } 930 case spansql.Bytes: 931 sv, ok := v.Kind.(*structpb.Value_StringValue) 932 if ok { 933 // The Spanner protocol encodes BYTES in base64. 934 return base64.StdEncoding.DecodeString(sv.StringValue) 935 } 936 case spansql.Date: 937 // The Spanner protocol encodes DATE in RFC 3339 date format. 938 sv, ok := v.Kind.(*structpb.Value_StringValue) 939 if ok { 940 s := sv.StringValue 941 d, err := parseAsDate(s) 942 if err != nil { 943 return nil, fmt.Errorf("bad DATE string %q: %v", s, err) 944 } 945 return d, nil 946 } 947 case spansql.Timestamp: 948 // The Spanner protocol encodes TIMESTAMP in RFC 3339 timestamp format with zone Z. 
949 sv, ok := v.Kind.(*structpb.Value_StringValue) 950 if ok { 951 s := sv.StringValue 952 if strings.ToLower(s) == "spanner.commit_timestamp()" { 953 return commitTimestampSentinel, nil 954 } 955 t, err := parseAsTimestamp(s) 956 if err != nil { 957 return nil, fmt.Errorf("bad TIMESTAMP string %q: %v", s, err) 958 } 959 return t, nil 960 } 961 } 962 return nil, fmt.Errorf("unsupported inserting value kind %T into column of type %s", v.Kind, t.SQL()) 963} 964 965type keyRange struct { 966 start, end *structpb.ListValue 967 startClosed, endClosed bool 968 969 // These are populated during an operation 970 // when we know what table this keyRange applies to. 971 startKey, endKey []interface{} 972} 973 974func (r *keyRange) String() string { 975 var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9. 976 if r.startClosed { 977 sb.WriteString("[") 978 } else { 979 sb.WriteString("(") 980 } 981 fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey) 982 if r.endClosed { 983 sb.WriteString("]") 984 } else { 985 sb.WriteString(")") 986 } 987 return sb.String() 988} 989 990type keyRangeList []*keyRange 991 992// Execute runs a DML statement. 993// It returns the number of affected rows. 994func (d *database) Execute(stmt spansql.DMLStmt, params queryParams) (int, error) { // TODO: return *status.Status instead? 
995 switch stmt := stmt.(type) { 996 default: 997 return 0, status.Errorf(codes.Unimplemented, "unhandled DML statement type %T", stmt) 998 case *spansql.Delete: 999 t, err := d.table(stmt.Table) 1000 if err != nil { 1001 return 0, err 1002 } 1003 1004 t.mu.Lock() 1005 defer t.mu.Unlock() 1006 1007 n := 0 1008 for i := 0; i < len(t.rows); { 1009 ec := evalContext{ 1010 cols: t.cols, 1011 row: t.rows[i], 1012 params: params, 1013 } 1014 b, err := ec.evalBoolExpr(stmt.Where) 1015 if err != nil { 1016 return 0, err 1017 } 1018 if b != nil && *b { 1019 copy(t.rows[i:], t.rows[i+1:]) 1020 t.rows = t.rows[:len(t.rows)-1] 1021 n++ 1022 continue 1023 } 1024 i++ 1025 } 1026 return n, nil 1027 case *spansql.Update: 1028 t, err := d.table(stmt.Table) 1029 if err != nil { 1030 return 0, err 1031 } 1032 1033 t.mu.Lock() 1034 defer t.mu.Unlock() 1035 1036 ec := evalContext{ 1037 cols: t.cols, 1038 params: params, 1039 } 1040 1041 // Build parallel slices of destination column index and expressions to evaluate. 1042 var dstIndex []int 1043 var expr []spansql.Expr 1044 for _, ui := range stmt.Items { 1045 i, err := ec.resolveColumnIndex(ui.Column) 1046 if err != nil { 1047 return 0, err 1048 } 1049 // TODO: Enforce "A column can appear only once in the SET clause.". 1050 if i < t.pkCols { 1051 return 0, status.Errorf(codes.InvalidArgument, "cannot update primary key %s", ui.Column) 1052 } 1053 dstIndex = append(dstIndex, i) 1054 expr = append(expr, ui.Value) 1055 } 1056 1057 n := 0 1058 values := make(row, len(stmt.Items)) // scratch space for new values 1059 for i := 0; i < len(t.rows); i++ { 1060 ec.row = t.rows[i] 1061 b, err := ec.evalBoolExpr(stmt.Where) 1062 if err != nil { 1063 return 0, err 1064 } 1065 if b != nil && *b { 1066 // Compute every update item. 
1067 for j := range dstIndex { 1068 if expr[j] == nil { // DEFAULT 1069 values[j] = nil 1070 continue 1071 } 1072 v, err := ec.evalExpr(expr[j]) 1073 if err != nil { 1074 return 0, err 1075 } 1076 values[j] = v 1077 } 1078 // Write them to the row. 1079 for j, v := range values { 1080 t.rows[i][dstIndex[j]] = v 1081 } 1082 n++ 1083 } 1084 } 1085 return n, nil 1086 } 1087} 1088 1089func parseAsDate(s string) (civil.Date, error) { return civil.ParseDate(s) } 1090func parseAsTimestamp(s string) (time.Time, error) { 1091 return time.Parse("2006-01-02T15:04:05.999999999Z", s) 1092} 1093