/*
Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package spannertest

// This file contains the implementation of the Spanner fake itself,
// namely the part behind the RPC interface.

// TODO: missing transactionality in a serious way!

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	structpb "github.com/golang/protobuf/ptypes/struct"

	"cloud.google.com/go/civil"
	"cloud.google.com/go/spanner/spansql"
)

// database holds the in-memory schema and data of one fake Spanner database.
type database struct {
	// mu is held in this file while reading or writing lastTS, tables and indexes.
	mu      sync.Mutex
	lastTS  time.Time               // last commit timestamp
	tables  map[spansql.ID]*table   // lazily initialised by ApplyDDL
	indexes map[spansql.ID]struct{} // only record their existence

	// rwMu serialises read-write transactions: it is acquired in
	// (*transaction).Start and released on Commit/Rollback.
	rwMu sync.Mutex // held by read-write transactions
}

// table holds the columns and rows of a single table.
type table struct {
	// mu is held in this file while reading or modifying the fields below.
	mu sync.Mutex

	// Information about the table columns.
	// They are reordered on table creation so the primary key columns come first.
	cols      []colInfo
	colIndex  map[spansql.ID]int // col name to index
	origIndex map[spansql.ID]int // original index of each column upon construction
	pkCols    int                // number of primary key columns (may be 0)
	pkDesc    []bool             // whether each primary key column is in descending order

	// Rows are stored in primary key order.
	// Each row has one element per entry in cols, in the same order.
	rows []row
}

// colInfo represents information about a column in a table or result set.
68type colInfo struct { 69 Name spansql.ID 70 Type spansql.Type 71 Generated spansql.Expr 72 NotNull bool // only set for table columns 73 AggIndex int // Index+1 of SELECT list for which this is an aggregate value. 74 Alias spansql.PathExp // an alternate name for this column (result sets only) 75} 76 77// commitTimestampSentinel is a sentinel value for TIMESTAMP fields with allow_commit_timestamp=true. 78// It is accepted, but never stored. 79var commitTimestampSentinel = &struct{}{} 80 81// transaction records information about a running transaction. 82// This is not safe for concurrent use. 83type transaction struct { 84 // readOnly is whether this transaction was constructed 85 // for read-only use, and should yield errors if used 86 // to perform a mutation. 87 readOnly bool 88 89 d *database 90 commitTimestamp time.Time // not set if readOnly 91 unlock func() // may be nil 92} 93 94func (d *database) NewReadOnlyTransaction() *transaction { 95 return &transaction{ 96 readOnly: true, 97 } 98} 99 100func (d *database) NewTransaction() *transaction { 101 return &transaction{ 102 d: d, 103 } 104} 105 106// Start starts the transaction and commits to a specific commit timestamp. 107// This also locks out any other read-write transaction on this database 108// until Commit/Rollback are called. 109func (tx *transaction) Start() { 110 // Commit timestamps are only guaranteed to be unique 111 // when transactions write to overlapping sets of fields. 112 // This simulated database exceeds that guarantee. 113 114 // Grab rwMu for the duration of this transaction. 115 // Take it before d.mu so we don't hold that lock 116 // while waiting for d.rwMu, which is held for longer. 
117 tx.d.rwMu.Lock() 118 119 tx.d.mu.Lock() 120 const tsRes = 1 * time.Microsecond 121 now := time.Now().UTC().Truncate(tsRes) 122 if !now.After(tx.d.lastTS) { 123 now = tx.d.lastTS.Add(tsRes) 124 } 125 tx.d.lastTS = now 126 tx.d.mu.Unlock() 127 128 tx.commitTimestamp = now 129 tx.unlock = tx.d.rwMu.Unlock 130} 131 132func (tx *transaction) checkMutable() error { 133 if tx.readOnly { 134 // TODO: is this the right status? 135 return status.Errorf(codes.InvalidArgument, "transaction is read-only") 136 } 137 return nil 138} 139 140func (tx *transaction) Commit() (time.Time, error) { 141 if tx.unlock != nil { 142 tx.unlock() 143 } 144 return tx.commitTimestamp, nil 145} 146 147func (tx *transaction) Rollback() { 148 if tx.unlock != nil { 149 tx.unlock() 150 } 151 // TODO: actually rollback 152} 153 154/* 155row represents a list of data elements. 156 157The mapping between Spanner types and Go types internal to this package are: 158 BOOL bool 159 INT64 int64 160 FLOAT64 float64 161 STRING string 162 BYTES []byte 163 DATE civil.Date 164 TIMESTAMP time.Time (location set to UTC) 165 ARRAY<T> []interface{} 166 STRUCT TODO 167*/ 168type row []interface{} 169 170func (r row) copyDataElem(index int) interface{} { 171 v := r[index] 172 if is, ok := v.([]interface{}); ok { 173 // Deep-copy array values. 174 v = append([]interface{}(nil), is...) 175 } 176 return v 177} 178 179// copyData returns a copy of the row. 180func (r row) copyAllData() row { 181 dst := make(row, 0, len(r)) 182 for i := range r { 183 dst = append(dst, r.copyDataElem(i)) 184 } 185 return dst 186} 187 188// copyData returns a copy of a subset of a row. 
189func (r row) copyData(indexes []int) row { 190 if len(indexes) == 0 { 191 return nil 192 } 193 dst := make(row, 0, len(indexes)) 194 for _, i := range indexes { 195 dst = append(dst, r.copyDataElem(i)) 196 } 197 return dst 198} 199 200func (d *database) LastCommitTimestamp() time.Time { 201 d.mu.Lock() 202 defer d.mu.Unlock() 203 return d.lastTS 204} 205 206func (d *database) GetDDL() []spansql.DDLStmt { 207 // This lacks fidelity, but captures the details we support. 208 d.mu.Lock() 209 defer d.mu.Unlock() 210 211 var stmts []spansql.DDLStmt 212 213 for name, t := range d.tables { 214 ct := &spansql.CreateTable{ 215 Name: name, 216 } 217 218 t.mu.Lock() 219 for i, col := range t.cols { 220 ct.Columns = append(ct.Columns, spansql.ColumnDef{ 221 Name: col.Name, 222 Type: col.Type, 223 NotNull: col.NotNull, 224 // TODO: AllowCommitTimestamp 225 }) 226 if i < t.pkCols { 227 ct.PrimaryKey = append(ct.PrimaryKey, spansql.KeyPart{ 228 Column: col.Name, 229 Desc: t.pkDesc[i], 230 }) 231 } 232 } 233 t.mu.Unlock() 234 235 stmts = append(stmts, ct) 236 } 237 238 return stmts 239} 240 241func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status { 242 d.mu.Lock() 243 defer d.mu.Unlock() 244 245 // Lazy init. 246 if d.tables == nil { 247 d.tables = make(map[spansql.ID]*table) 248 } 249 if d.indexes == nil { 250 d.indexes = make(map[spansql.ID]struct{}) 251 } 252 253 switch stmt := stmt.(type) { 254 default: 255 return status.Newf(codes.Unimplemented, "unhandled DDL statement type %T", stmt) 256 case *spansql.CreateTable: 257 if _, ok := d.tables[stmt.Name]; ok { 258 return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name) 259 } 260 if len(stmt.PrimaryKey) == 0 { 261 return status.Newf(codes.InvalidArgument, "table %s has no primary key", stmt.Name) 262 } 263 264 // TODO: check stmt.Interleave details. 265 266 // Record original column ordering. 
267 orig := make(map[spansql.ID]int) 268 for i, col := range stmt.Columns { 269 orig[col.Name] = i 270 } 271 272 // Move primary keys first, preserving their order. 273 pk := make(map[spansql.ID]int) 274 var pkDesc []bool 275 for i, kp := range stmt.PrimaryKey { 276 pk[kp.Column] = -1000 + i 277 pkDesc = append(pkDesc, kp.Desc) 278 } 279 sort.SliceStable(stmt.Columns, func(i, j int) bool { 280 a, b := pk[stmt.Columns[i].Name], pk[stmt.Columns[j].Name] 281 return a < b 282 }) 283 284 t := &table{ 285 colIndex: make(map[spansql.ID]int), 286 origIndex: orig, 287 pkCols: len(pk), 288 pkDesc: pkDesc, 289 } 290 for _, cd := range stmt.Columns { 291 if st := t.addColumn(cd, true); st.Code() != codes.OK { 292 return st 293 } 294 } 295 for col := range pk { 296 if _, ok := t.colIndex[col]; !ok { 297 return status.Newf(codes.InvalidArgument, "primary key column %q not in table", col) 298 } 299 } 300 d.tables[stmt.Name] = t 301 return nil 302 case *spansql.CreateIndex: 303 if _, ok := d.indexes[stmt.Name]; ok { 304 return status.Newf(codes.AlreadyExists, "index %s already exists", stmt.Name) 305 } 306 d.indexes[stmt.Name] = struct{}{} 307 return nil 308 case *spansql.DropTable: 309 if _, ok := d.tables[stmt.Name]; !ok { 310 return status.Newf(codes.NotFound, "no table named %s", stmt.Name) 311 } 312 // TODO: check for indexes on this table. 
313 delete(d.tables, stmt.Name) 314 return nil 315 case *spansql.DropIndex: 316 if _, ok := d.indexes[stmt.Name]; !ok { 317 return status.Newf(codes.NotFound, "no index named %s", stmt.Name) 318 } 319 delete(d.indexes, stmt.Name) 320 return nil 321 case *spansql.AlterTable: 322 t, ok := d.tables[stmt.Name] 323 if !ok { 324 return status.Newf(codes.NotFound, "no table named %s", stmt.Name) 325 } 326 switch alt := stmt.Alteration.(type) { 327 default: 328 return status.Newf(codes.Unimplemented, "unhandled DDL table alteration type %T", alt) 329 case spansql.AddColumn: 330 if st := t.addColumn(alt.Def, false); st.Code() != codes.OK { 331 return st 332 } 333 return nil 334 case spansql.DropColumn: 335 if st := t.dropColumn(alt.Name); st.Code() != codes.OK { 336 return st 337 } 338 return nil 339 case spansql.AlterColumn: 340 if st := t.alterColumn(alt); st.Code() != codes.OK { 341 return st 342 } 343 return nil 344 } 345 } 346 347} 348 349func (d *database) table(tbl spansql.ID) (*table, error) { 350 d.mu.Lock() 351 defer d.mu.Unlock() 352 353 t, ok := d.tables[tbl] 354 if !ok { 355 return nil, status.Errorf(codes.NotFound, "no table named %s", tbl) 356 } 357 return t, nil 358} 359 360// writeValues executes a write option (Insert, Update, etc.). 
361func (d *database) writeValues(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error { 362 if err := tx.checkMutable(); err != nil { 363 return err 364 } 365 366 t, err := d.table(tbl) 367 if err != nil { 368 return err 369 } 370 371 t.mu.Lock() 372 defer t.mu.Unlock() 373 374 colIndexes, err := t.colIndexes(cols) 375 if err != nil { 376 return err 377 } 378 revIndex := make(map[int]int) // table index to col index 379 for j, i := range colIndexes { 380 revIndex[i] = j 381 } 382 383 for pki := 0; pki < t.pkCols; pki++ { 384 _, ok := revIndex[pki] 385 if !ok { 386 return status.Errorf(codes.InvalidArgument, "primary key column %s not included in write", t.cols[pki].Name) 387 } 388 } 389 390 for _, vs := range values { 391 if len(vs.Values) != len(colIndexes) { 392 return status.Errorf(codes.InvalidArgument, "row of %d values can't be written to %d columns", len(vs.Values), len(colIndexes)) 393 } 394 395 r := make(row, len(t.cols)) 396 for j, v := range vs.Values { 397 i := colIndexes[j] 398 399 if t.cols[i].Generated != nil { 400 return status.Error(codes.InvalidArgument, "values can't be written to a generated column") 401 } 402 x, err := valForType(v, t.cols[i].Type) 403 if err != nil { 404 return err 405 } 406 if x == commitTimestampSentinel { 407 x = tx.commitTimestamp 408 } 409 if x == nil && t.cols[i].NotNull { 410 return status.Errorf(codes.FailedPrecondition, "%s must not be NULL in table %s", t.cols[i].Name, tbl) 411 } 412 413 r[i] = x 414 } 415 // TODO: enforce that provided timestamp for commit_timestamp=true columns 416 // are not ahead of the transaction's commit timestamp. 417 418 if err := f(t, colIndexes, r); err != nil { 419 return err 420 } 421 422 // Get row again after potential update merge to ensure we compute 423 // generated columns with fresh data. 
424 pk := r[:t.pkCols] 425 rowNum, found := t.rowForPK(pk) 426 // This should never fail as the row was just inserted. 427 if !found { 428 return status.Error(codes.Internal, "row failed to be inserted") 429 } 430 row := t.rows[rowNum] 431 ec := evalContext{ 432 cols: t.cols, 433 row: row, 434 } 435 436 // TODO: We would need to do a topological sort on dependencies 437 // (i.e. what other columns the expression references) to ensure we 438 // can handle generated columns which reference other generated columns 439 for i, col := range t.cols { 440 if col.Generated != nil { 441 res, err := ec.evalExpr(col.Generated) 442 if err != nil { 443 return err 444 } 445 row[i] = res 446 } 447 } 448 } 449 450 return nil 451} 452 453func (d *database) Insert(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { 454 return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { 455 pk := r[:t.pkCols] 456 rowNum, found := t.rowForPK(pk) 457 if found { 458 return status.Errorf(codes.AlreadyExists, "row already in table") 459 } 460 t.insertRow(rowNum, r) 461 return nil 462 }) 463} 464 465func (d *database) Update(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { 466 return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { 467 if t.pkCols == 0 { 468 return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl) 469 } 470 pk := r[:t.pkCols] 471 rowNum, found := t.rowForPK(pk) 472 if !found { 473 // TODO: is this the right way to return `NOT_FOUND`? 
474 return status.Errorf(codes.NotFound, "row not in table") 475 } 476 477 for _, i := range colIndexes { 478 t.rows[rowNum][i] = r[i] 479 } 480 return nil 481 }) 482} 483 484func (d *database) InsertOrUpdate(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { 485 return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { 486 pk := r[:t.pkCols] 487 rowNum, found := t.rowForPK(pk) 488 if !found { 489 // New row; do an insert. 490 t.insertRow(rowNum, r) 491 } else { 492 // Existing row; do an update. 493 for _, i := range colIndexes { 494 t.rows[rowNum][i] = r[i] 495 } 496 } 497 return nil 498 }) 499} 500 501// TODO: Replace 502 503func (d *database) Delete(tx *transaction, table spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error { 504 if err := tx.checkMutable(); err != nil { 505 return err 506 } 507 508 t, err := d.table(table) 509 if err != nil { 510 return err 511 } 512 513 t.mu.Lock() 514 defer t.mu.Unlock() 515 516 if all { 517 t.rows = nil 518 return nil 519 } 520 521 for _, key := range keys { 522 pk, err := t.primaryKey(key.Values) 523 if err != nil { 524 return err 525 } 526 // Not an error if the key does not exist. 527 rowNum, found := t.rowForPK(pk) 528 if found { 529 copy(t.rows[rowNum:], t.rows[rowNum+1:]) 530 t.rows = t.rows[:len(t.rows)-1] 531 } 532 } 533 534 for _, r := range keyRanges { 535 r.startKey, err = t.primaryKeyPrefix(r.start.Values) 536 if err != nil { 537 return err 538 } 539 r.endKey, err = t.primaryKeyPrefix(r.end.Values) 540 if err != nil { 541 return err 542 } 543 startRow, endRow := t.findRange(r) 544 if n := endRow - startRow; n > 0 { 545 copy(t.rows[startRow:], t.rows[endRow:]) 546 t.rows = t.rows[:len(t.rows)-n] 547 } 548 } 549 550 return nil 551} 552 553// readTable executes a read option (Read, ReadAll). 
554func (d *database) readTable(table spansql.ID, cols []spansql.ID, f func(*table, *rawIter, []int) error) (*rawIter, error) { 555 t, err := d.table(table) 556 if err != nil { 557 return nil, err 558 } 559 560 t.mu.Lock() 561 defer t.mu.Unlock() 562 563 colIndexes, err := t.colIndexes(cols) 564 if err != nil { 565 return nil, err 566 } 567 568 ri := &rawIter{} 569 for _, i := range colIndexes { 570 ri.cols = append(ri.cols, t.cols[i]) 571 } 572 return ri, f(t, ri, colIndexes) 573} 574 575func (d *database) Read(tbl spansql.ID, cols []spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, limit int64) (rowIter, error) { 576 // The real Cloud Spanner returns an error if the key set is empty by definition. 577 // That doesn't seem to be well-defined, but it is a common error to attempt a read with no keys, 578 // so catch that here and return a representative error. 579 if len(keys) == 0 && len(keyRanges) == 0 { 580 return nil, status.Error(codes.Unimplemented, "Cloud Spanner does not support reading no keys") 581 } 582 583 return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error { 584 // "If the same key is specified multiple times in the set (for 585 // example if two ranges, two keys, or a key and a range 586 // overlap), Cloud Spanner behaves as if the key were only 587 // specified once." 588 done := make(map[int]bool) // row numbers we've included in ri. 589 590 // Specific keys. 591 for _, key := range keys { 592 pk, err := t.primaryKey(key.Values) 593 if err != nil { 594 return err 595 } 596 // Not an error if the key does not exist. 597 rowNum, found := t.rowForPK(pk) 598 if !found { 599 continue 600 } 601 if done[rowNum] { 602 continue 603 } 604 done[rowNum] = true 605 ri.add(t.rows[rowNum], colIndexes) 606 if limit > 0 && len(ri.rows) >= int(limit) { 607 return nil 608 } 609 } 610 611 // Key ranges. 
612 for _, r := range keyRanges { 613 var err error 614 r.startKey, err = t.primaryKeyPrefix(r.start.Values) 615 if err != nil { 616 return err 617 } 618 r.endKey, err = t.primaryKeyPrefix(r.end.Values) 619 if err != nil { 620 return err 621 } 622 startRow, endRow := t.findRange(r) 623 for rowNum := startRow; rowNum < endRow; rowNum++ { 624 if done[rowNum] { 625 continue 626 } 627 done[rowNum] = true 628 ri.add(t.rows[rowNum], colIndexes) 629 if limit > 0 && len(ri.rows) >= int(limit) { 630 return nil 631 } 632 } 633 } 634 635 return nil 636 }) 637} 638 639func (d *database) ReadAll(tbl spansql.ID, cols []spansql.ID, limit int64) (*rawIter, error) { 640 return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error { 641 for _, r := range t.rows { 642 ri.add(r, colIndexes) 643 if limit > 0 && len(ri.rows) >= int(limit) { 644 break 645 } 646 } 647 return nil 648 }) 649} 650 651func (t *table) addColumn(cd spansql.ColumnDef, newTable bool) *status.Status { 652 if !newTable && cd.NotNull { 653 return status.Newf(codes.InvalidArgument, "new non-key columns cannot be NOT NULL") 654 } 655 656 if _, ok := t.colIndex[cd.Name]; ok { 657 return status.Newf(codes.AlreadyExists, "column %s already exists", cd.Name) 658 } 659 660 t.mu.Lock() 661 defer t.mu.Unlock() 662 663 if len(t.rows) > 0 { 664 if cd.NotNull { 665 // TODO: what happens in this case? 
666 return status.Newf(codes.Unimplemented, "can't add NOT NULL columns to non-empty tables yet") 667 } 668 if cd.Generated != nil { 669 // TODO: should backfill the data to maintain behaviour with real spanner 670 return status.Newf(codes.Unimplemented, "can't add generated columns to non-empty tables yet") 671 } 672 for i := range t.rows { 673 t.rows[i] = append(t.rows[i], nil) 674 } 675 } 676 677 t.cols = append(t.cols, colInfo{ 678 Name: cd.Name, 679 Type: cd.Type, 680 NotNull: cd.NotNull, 681 // TODO: We should figure out what columns the Generator expression 682 // relies on and check it is valid at this time currently it will 683 // fail when writing data instead as it is the first time we 684 // evaluate the expression. 685 Generated: cd.Generated, 686 }) 687 t.colIndex[cd.Name] = len(t.cols) - 1 688 if !newTable { 689 t.origIndex[cd.Name] = len(t.cols) - 1 690 } 691 692 return nil 693} 694 695func (t *table) dropColumn(name spansql.ID) *status.Status { 696 // Only permit dropping non-key columns that aren't part of a secondary index. 697 // We don't support indexes, so only check that it isn't part of the primary key. 698 699 t.mu.Lock() 700 defer t.mu.Unlock() 701 702 ci, ok := t.colIndex[name] 703 if !ok { 704 // TODO: What's the right response code? 705 return status.Newf(codes.InvalidArgument, "unknown column %q", name) 706 } 707 if ci < t.pkCols { 708 // TODO: What's the right response code? 709 return status.Newf(codes.InvalidArgument, "can't drop primary key column %q", name) 710 } 711 712 // Remove from cols and colIndex, and renumber colIndex and origIndex. 713 t.cols = append(t.cols[:ci], t.cols[ci+1:]...) 714 delete(t.colIndex, name) 715 for i, col := range t.cols { 716 t.colIndex[col.Name] = i 717 } 718 pre := t.origIndex[name] 719 delete(t.origIndex, name) 720 for n, i := range t.origIndex { 721 if i > pre { 722 t.origIndex[n]-- 723 } 724 } 725 726 // Drop data. 
727 for i := range t.rows { 728 t.rows[i] = append(t.rows[i][:ci], t.rows[i][ci+1:]...) 729 } 730 731 return nil 732} 733 734func (t *table) alterColumn(alt spansql.AlterColumn) *status.Status { 735 // Supported changes here are: 736 // Add NOT NULL to a non-key column, excluding ARRAY columns. 737 // Remove NOT NULL from a non-key column. 738 // Change a STRING column to a BYTES column or a BYTES column to a STRING column. 739 // Increase or decrease the length limit for a STRING or BYTES type (including to MAX). 740 // Enable or disable commit timestamps in value and primary key columns. 741 // https://cloud.google.com/spanner/docs/schema-updates#supported-updates 742 743 // TODO: codes.InvalidArgument is used throughout here for reporting errors, 744 // but that has not been validated against the real Spanner. 745 746 sct, ok := alt.Alteration.(spansql.SetColumnType) 747 if !ok { 748 return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL()) 749 } 750 751 t.mu.Lock() 752 defer t.mu.Unlock() 753 754 ci, ok := t.colIndex[alt.Name] 755 if !ok { 756 return status.Newf(codes.InvalidArgument, "unknown column %q", alt.Name) 757 } 758 759 oldT, newT := t.cols[ci].Type, sct.Type 760 stringOrBytes := func(bt spansql.TypeBase) bool { return bt == spansql.String || bt == spansql.Bytes } 761 762 // First phase: Check the validity of the change. 763 // TODO: Don't permit changes to allow commit timestamps. 764 if !t.cols[ci].NotNull && sct.NotNull { 765 // Adding NOT NULL is not permitted for primary key columns or array typed columns. 766 if ci < t.pkCols { 767 return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on primary key column %q", alt.Name) 768 } 769 if oldT.Array { 770 return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on array-typed column %q", alt.Name) 771 } 772 // Validate that there are no NULL values. 
773 for _, row := range t.rows { 774 if row[ci] == nil { 775 return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on column %q that contains NULL values", alt.Name) 776 } 777 } 778 } 779 var conv func(x interface{}) interface{} 780 if stringOrBytes(oldT.Base) && stringOrBytes(newT.Base) && !oldT.Array && !newT.Array { 781 // Change between STRING and BYTES is fine, as is increasing/decreasing the length limit. 782 // TODO: This should permit array conversions too. 783 // TODO: Validate data; length limit changes should be rejected if they'd lead to data loss, for instance. 784 if oldT.Base == spansql.Bytes && newT.Base == spansql.String { 785 conv = func(x interface{}) interface{} { return string(x.([]byte)) } 786 } else if oldT.Base == spansql.String && newT.Base == spansql.Bytes { 787 conv = func(x interface{}) interface{} { return []byte(x.(string)) } 788 } 789 } else if oldT == newT { 790 // Same type; only NOT NULL changes. 791 } else { // TODO: Support other alterations. 792 return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL()) 793 } 794 795 // Second phase: Make type transformations. 796 t.cols[ci].NotNull = sct.NotNull 797 t.cols[ci].Type = newT 798 if conv != nil { 799 for _, row := range t.rows { 800 if row[ci] != nil { // NULL stays as NULL. 801 row[ci] = conv(row[ci]) 802 } 803 } 804 } 805 return nil 806} 807 808func (t *table) insertRow(rowNum int, r row) { 809 t.rows = append(t.rows, nil) 810 copy(t.rows[rowNum+1:], t.rows[rowNum:]) 811 t.rows[rowNum] = r 812} 813 814// findRange finds the rows included in the key range, 815// reporting it as a half-open interval. 816// r.startKey and r.endKey should be populated. 817func (t *table) findRange(r *keyRange) (int, int) { 818 // startRow is the first row matching the range. 
819 startRow := sort.Search(len(t.rows), func(i int) bool { 820 return rowCmp(r.startKey, t.rows[i][:t.pkCols], t.pkDesc) <= 0 821 }) 822 if startRow == len(t.rows) { 823 return startRow, startRow 824 } 825 if !r.startClosed && rowCmp(r.startKey, t.rows[startRow][:t.pkCols], t.pkDesc) == 0 { 826 startRow++ 827 } 828 829 // endRow is one more than the last row matching the range. 830 endRow := sort.Search(len(t.rows), func(i int) bool { 831 return rowCmp(r.endKey, t.rows[i][:t.pkCols], t.pkDesc) < 0 832 }) 833 if !r.endClosed && rowCmp(r.endKey, t.rows[endRow-1][:t.pkCols], t.pkDesc) == 0 { 834 endRow-- 835 } 836 837 return startRow, endRow 838} 839 840// colIndexes returns the indexes for the named columns. 841func (t *table) colIndexes(cols []spansql.ID) ([]int, error) { 842 var is []int 843 for _, col := range cols { 844 i, ok := t.colIndex[col] 845 if !ok { 846 return nil, status.Errorf(codes.InvalidArgument, "column %s not in table", col) 847 } 848 is = append(is, i) 849 } 850 return is, nil 851} 852 853// primaryKey constructs the internal representation of a primary key. 854// The list of given values must be in 1:1 correspondence with the primary key of the table. 855func (t *table) primaryKey(values []*structpb.Value) ([]interface{}, error) { 856 if len(values) != t.pkCols { 857 return nil, status.Errorf(codes.InvalidArgument, "primary key length mismatch: got %d values, table has %d", len(values), t.pkCols) 858 } 859 return t.primaryKeyPrefix(values) 860} 861 862// primaryKeyPrefix constructs the internal representation of a primary key prefix. 
863func (t *table) primaryKeyPrefix(values []*structpb.Value) ([]interface{}, error) { 864 if len(values) > t.pkCols { 865 return nil, status.Errorf(codes.InvalidArgument, "primary key length too long: got %d values, table has %d", len(values), t.pkCols) 866 } 867 868 var pk []interface{} 869 for i, value := range values { 870 v, err := valForType(value, t.cols[i].Type) 871 if err != nil { 872 return nil, err 873 } 874 pk = append(pk, v) 875 } 876 return pk, nil 877} 878 879// rowForPK returns the index of t.rows that holds the row for the given primary key, and true. 880// If the given primary key isn't found, it returns the row that should hold it, and false. 881func (t *table) rowForPK(pk []interface{}) (row int, found bool) { 882 if len(pk) != t.pkCols { 883 panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols)) 884 } 885 886 i := sort.Search(len(t.rows), func(i int) bool { 887 return rowCmp(pk, t.rows[i][:t.pkCols], t.pkDesc) <= 0 888 }) 889 if i == len(t.rows) { 890 return i, false 891 } 892 return i, rowEqual(pk, t.rows[i][:t.pkCols]) 893} 894 895// rowCmp compares two rows, returning -1/0/+1. 896// The desc arg indicates whether each column is in a descending order. 897// This is used for primary key matching and so doesn't support array/struct types. 898// a is permitted to be shorter than b. 899func rowCmp(a, b []interface{}, desc []bool) int { 900 for i := 0; i < len(a); i++ { 901 if cmp := compareVals(a[i], b[i]); cmp != 0 { 902 if desc[i] { 903 cmp = -cmp 904 } 905 return cmp 906 } 907 } 908 return 0 909} 910 911// rowEqual reports whether two rows are equal. 912// This doesn't support array/struct types. 913func rowEqual(a, b []interface{}) bool { 914 for i := 0; i < len(a); i++ { 915 if compareVals(a[i], b[i]) != 0 { 916 return false 917 } 918 } 919 return true 920} 921 922// valForType converts a value from its RPC form into its internal representation. 
923func valForType(v *structpb.Value, t spansql.Type) (interface{}, error) { 924 if _, ok := v.Kind.(*structpb.Value_NullValue); ok { 925 return nil, nil 926 } 927 928 if lv, ok := v.Kind.(*structpb.Value_ListValue); ok && t.Array { 929 et := t // element type 930 et.Array = false 931 932 // Construct the non-nil slice for the list. 933 arr := make([]interface{}, 0, len(lv.ListValue.Values)) 934 for _, v := range lv.ListValue.Values { 935 x, err := valForType(v, et) 936 if err != nil { 937 return nil, err 938 } 939 arr = append(arr, x) 940 } 941 return arr, nil 942 } 943 944 switch t.Base { 945 case spansql.Bool: 946 bv, ok := v.Kind.(*structpb.Value_BoolValue) 947 if ok { 948 return bv.BoolValue, nil 949 } 950 case spansql.Int64: 951 // The Spanner protocol encodes int64 as a decimal string. 952 sv, ok := v.Kind.(*structpb.Value_StringValue) 953 if ok { 954 x, err := strconv.ParseInt(sv.StringValue, 10, 64) 955 if err != nil { 956 return nil, fmt.Errorf("bad int64 string %q: %v", sv.StringValue, err) 957 } 958 return x, nil 959 } 960 case spansql.Float64: 961 nv, ok := v.Kind.(*structpb.Value_NumberValue) 962 if ok { 963 return nv.NumberValue, nil 964 } 965 case spansql.String: 966 sv, ok := v.Kind.(*structpb.Value_StringValue) 967 if ok { 968 return sv.StringValue, nil 969 } 970 case spansql.Bytes: 971 sv, ok := v.Kind.(*structpb.Value_StringValue) 972 if ok { 973 // The Spanner protocol encodes BYTES in base64. 974 return base64.StdEncoding.DecodeString(sv.StringValue) 975 } 976 case spansql.Date: 977 // The Spanner protocol encodes DATE in RFC 3339 date format. 978 sv, ok := v.Kind.(*structpb.Value_StringValue) 979 if ok { 980 s := sv.StringValue 981 d, err := parseAsDate(s) 982 if err != nil { 983 return nil, fmt.Errorf("bad DATE string %q: %v", s, err) 984 } 985 return d, nil 986 } 987 case spansql.Timestamp: 988 // The Spanner protocol encodes TIMESTAMP in RFC 3339 timestamp format with zone Z. 
989 sv, ok := v.Kind.(*structpb.Value_StringValue) 990 if ok { 991 s := sv.StringValue 992 if strings.ToLower(s) == "spanner.commit_timestamp()" { 993 return commitTimestampSentinel, nil 994 } 995 t, err := parseAsTimestamp(s) 996 if err != nil { 997 return nil, fmt.Errorf("bad TIMESTAMP string %q: %v", s, err) 998 } 999 return t, nil 1000 } 1001 } 1002 return nil, fmt.Errorf("unsupported inserting value kind %T into column of type %s", v.Kind, t.SQL()) 1003} 1004 1005type keyRange struct { 1006 start, end *structpb.ListValue 1007 startClosed, endClosed bool 1008 1009 // These are populated during an operation 1010 // when we know what table this keyRange applies to. 1011 startKey, endKey []interface{} 1012} 1013 1014func (r *keyRange) String() string { 1015 var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9. 1016 if r.startClosed { 1017 sb.WriteString("[") 1018 } else { 1019 sb.WriteString("(") 1020 } 1021 fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey) 1022 if r.endClosed { 1023 sb.WriteString("]") 1024 } else { 1025 sb.WriteString(")") 1026 } 1027 return sb.String() 1028} 1029 1030type keyRangeList []*keyRange 1031 1032// Execute runs a DML statement. 1033// It returns the number of affected rows. 1034func (d *database) Execute(stmt spansql.DMLStmt, params queryParams) (int, error) { // TODO: return *status.Status instead? 
1035 switch stmt := stmt.(type) { 1036 default: 1037 return 0, status.Errorf(codes.Unimplemented, "unhandled DML statement type %T", stmt) 1038 case *spansql.Delete: 1039 t, err := d.table(stmt.Table) 1040 if err != nil { 1041 return 0, err 1042 } 1043 1044 t.mu.Lock() 1045 defer t.mu.Unlock() 1046 1047 n := 0 1048 for i := 0; i < len(t.rows); { 1049 ec := evalContext{ 1050 cols: t.cols, 1051 row: t.rows[i], 1052 params: params, 1053 } 1054 b, err := ec.evalBoolExpr(stmt.Where) 1055 if err != nil { 1056 return 0, err 1057 } 1058 if b != nil && *b { 1059 copy(t.rows[i:], t.rows[i+1:]) 1060 t.rows = t.rows[:len(t.rows)-1] 1061 n++ 1062 continue 1063 } 1064 i++ 1065 } 1066 return n, nil 1067 case *spansql.Update: 1068 t, err := d.table(stmt.Table) 1069 if err != nil { 1070 return 0, err 1071 } 1072 1073 t.mu.Lock() 1074 defer t.mu.Unlock() 1075 1076 ec := evalContext{ 1077 cols: t.cols, 1078 params: params, 1079 } 1080 1081 // Build parallel slices of destination column index and expressions to evaluate. 1082 var dstIndex []int 1083 var expr []spansql.Expr 1084 for _, ui := range stmt.Items { 1085 i, err := ec.resolveColumnIndex(ui.Column) 1086 if err != nil { 1087 return 0, err 1088 } 1089 // TODO: Enforce "A column can appear only once in the SET clause.". 1090 if i < t.pkCols { 1091 return 0, status.Errorf(codes.InvalidArgument, "cannot update primary key %s", ui.Column) 1092 } 1093 dstIndex = append(dstIndex, i) 1094 expr = append(expr, ui.Value) 1095 } 1096 1097 n := 0 1098 values := make(row, len(stmt.Items)) // scratch space for new values 1099 for i := 0; i < len(t.rows); i++ { 1100 ec.row = t.rows[i] 1101 b, err := ec.evalBoolExpr(stmt.Where) 1102 if err != nil { 1103 return 0, err 1104 } 1105 if b != nil && *b { 1106 // Compute every update item. 
1107 for j := range dstIndex { 1108 if expr[j] == nil { // DEFAULT 1109 values[j] = nil 1110 continue 1111 } 1112 v, err := ec.evalExpr(expr[j]) 1113 if err != nil { 1114 return 0, err 1115 } 1116 values[j] = v 1117 } 1118 // Write them to the row. 1119 for j, v := range values { 1120 t.rows[i][dstIndex[j]] = v 1121 } 1122 n++ 1123 } 1124 } 1125 return n, nil 1126 } 1127} 1128 1129func parseAsDate(s string) (civil.Date, error) { return civil.ParseDate(s) } 1130func parseAsTimestamp(s string) (time.Time, error) { 1131 return time.Parse("2006-01-02T15:04:05.999999999Z", s) 1132} 1133