1/* 2Copyright 2019 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package spannertest 18 19// This file contains the implementation of the Spanner fake itself, 20// namely the part behind the RPC interface. 21 22// TODO: missing transactionality in a serious way! 23 24import ( 25 "bytes" 26 "encoding/base64" 27 "fmt" 28 "sort" 29 "strconv" 30 "sync" 31 "time" 32 33 "google.golang.org/grpc/codes" 34 "google.golang.org/grpc/status" 35 36 structpb "github.com/golang/protobuf/ptypes/struct" 37 38 "cloud.google.com/go/spanner/spansql" 39) 40 41type database struct { 42 mu sync.Mutex 43 tables map[string]*table 44 indexes map[string]struct{} // only record their existence 45} 46 47type table struct { 48 mu sync.Mutex 49 50 // Information about the table columns. 51 // They are reordered on table creation so the primary key columns come first. 52 cols []colInfo 53 colIndex map[string]int // col name to index 54 pkCols int // number of primary key columns (may be 0) 55 56 // Rows are stored in primary key order. 57 rows []row 58} 59 60// colInfo represents information about a column in a table or result set. 61type colInfo struct { 62 Name string 63 Type spansql.Type 64} 65 66/* 67row represents a list of data elements. 68 69The mapping between Spanner types and Go types internal to this package are: 70 BOOL bool 71 INT64 int64 72 FLOAT64 float64 73 STRING string 74 BYTES []byte 75 DATE string (RFC 3339 date; "YYYY-MM-DD") 76 TIMESTAMP TODO 77 ARRAY<T> []T 78 STRUCT TODO 79*/ 80type row []interface{} 81 82func (r row) copyDataElem(index int) interface{} { 83 v := r[index] 84 if is, ok := v.([]interface{}); ok { 85 // Deep-copy array values. 86 v = append([]interface{}(nil), is...) 87 } 88 return v 89} 90 91// copyData returns a copy of a subset of a row. 92func (r row) copyData(indexes []int) row { 93 if len(indexes) == 0 { 94 return nil 95 } 96 dst := make(row, 0, len(indexes)) 97 for _, i := range indexes { 98 dst = append(dst, r.copyDataElem(i)) 99 } 100 return dst 101} 102 103func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status { 104 d.mu.Lock() 105 defer d.mu.Unlock() 106 107 // Lazy init. 108 if d.tables == nil { 109 d.tables = make(map[string]*table) 110 } 111 if d.indexes == nil { 112 d.indexes = make(map[string]struct{}) 113 } 114 115 switch stmt := stmt.(type) { 116 default: 117 return status.Newf(codes.Unimplemented, "unhandled DDL statement type %T", stmt) 118 case spansql.CreateTable: 119 if _, ok := d.tables[stmt.Name]; ok { 120 return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name) 121 } 122 123 // TODO: check stmt.Interleave details. 124 125 // Move primary keys first, preserving their order. 126 pk := make(map[string]int) 127 for i, kp := range stmt.PrimaryKey { 128 pk[kp.Column] = -1000 + i 129 } 130 sort.SliceStable(stmt.Columns, func(i, j int) bool { 131 a, b := pk[stmt.Columns[i].Name], pk[stmt.Columns[j].Name] 132 return a < b 133 }) 134 135 t := &table{ 136 colIndex: make(map[string]int), 137 pkCols: len(pk), 138 } 139 for _, cd := range stmt.Columns { 140 if st := t.addColumn(cd); st.Code() != codes.OK { 141 return st 142 } 143 } 144 for col := range pk { 145 if _, ok := t.colIndex[col]; !ok { 146 return status.Newf(codes.InvalidArgument, "primary key column %q not in table", col) 147 } 148 } 149 d.tables[stmt.Name] = t 150 return nil 151 case spansql.CreateIndex: 152 if _, ok := d.indexes[stmt.Name]; ok { 153 return status.Newf(codes.AlreadyExists, "index %s already exists", stmt.Name) 154 } 155 d.indexes[stmt.Name] = struct{}{} 156 return nil 157 case spansql.DropTable: 158 if _, ok := d.tables[stmt.Name]; !ok { 159 return status.Newf(codes.NotFound, "no table named %s", stmt.Name) 160 } 161 // TODO: check for indexes on this table. 162 delete(d.tables, stmt.Name) 163 return nil 164 case spansql.DropIndex: 165 if _, ok := d.indexes[stmt.Name]; !ok { 166 return status.Newf(codes.NotFound, "no index named %s", stmt.Name) 167 } 168 delete(d.indexes, stmt.Name) 169 return nil 170 case spansql.AlterTable: 171 t, ok := d.tables[stmt.Name] 172 if !ok { 173 return status.Newf(codes.NotFound, "no table named %s", stmt.Name) 174 } 175 switch alt := stmt.Alteration.(type) { 176 default: 177 return status.Newf(codes.Unimplemented, "unhandled DDL table alteration type %T", alt) 178 case spansql.AddColumn: 179 if alt.Def.NotNull { 180 return status.Newf(codes.InvalidArgument, "new non-key columns cannot be NOT NULL") 181 } 182 if st := t.addColumn(alt.Def); st.Code() != codes.OK { 183 return st 184 } 185 return nil 186 } 187 } 188 189} 190 191func (d *database) table(tbl string) (*table, error) { 192 d.mu.Lock() 193 defer d.mu.Unlock() 194 195 t, ok := d.tables[tbl] 196 if !ok { 197 return nil, status.Errorf(codes.NotFound, "no table named %s", tbl) 198 } 199 return t, nil 200} 201 202// writeValues executes a write option (Insert, Update, etc.). 203func (d *database) writeValues(tbl string, cols []string, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error { 204 t, err := d.table(tbl) 205 if err != nil { 206 return err 207 } 208 209 t.mu.Lock() 210 defer t.mu.Unlock() 211 212 colIndexes, err := t.colIndexes(cols) 213 if err != nil { 214 return err 215 } 216 revIndex := make(map[int]int) // table index to col index 217 for j, i := range colIndexes { 218 revIndex[i] = j 219 } 220 221 for pki := 0; pki < t.pkCols; pki++ { 222 _, ok := revIndex[pki] 223 if !ok { 224 return status.Errorf(codes.InvalidArgument, "primary key column %s not included in write", t.cols[pki].Name) 225 } 226 } 227 228 for _, vs := range values { 229 if len(vs.Values) != len(colIndexes) { 230 return status.Errorf(codes.InvalidArgument, "row of %d values can't be written to %d columns", len(vs.Values), len(colIndexes)) 231 } 232 233 r := make(row, len(t.cols)) 234 for j, v := range vs.Values { 235 i := colIndexes[j] 236 237 x, err := valForType(v, t.cols[i].Type) 238 if err != nil { 239 return err 240 } 241 242 r[i] = x 243 } 244 // TODO: enforce NOT NULL? 245 246 if err := f(t, colIndexes, r); err != nil { 247 return err 248 } 249 } 250 251 return nil 252} 253 254func (d *database) Insert(tbl string, cols []string, values []*structpb.ListValue) error { 255 return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error { 256 pk := r[:t.pkCols] 257 rowNum, found := t.rowForPK(pk) 258 if found { 259 return status.Errorf(codes.AlreadyExists, "row already in table") 260 } 261 t.insertRow(rowNum, r) 262 return nil 263 }) 264} 265 266func (d *database) Update(tbl string, cols []string, values []*structpb.ListValue) error { 267 return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error { 268 if t.pkCols == 0 { 269 return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl) 270 } 271 pk := r[:t.pkCols] 272 rowNum, found := t.rowForPK(pk) 273 if !found { 274 // TODO: is this the right way to return `NOT_FOUND`? 275 return status.Errorf(codes.NotFound, "row not in table") 276 } 277 278 for _, i := range colIndexes { 279 t.rows[rowNum][i] = r[i] 280 } 281 return nil 282 }) 283} 284 285func (d *database) InsertOrUpdate(tbl string, cols []string, values []*structpb.ListValue) error { 286 return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error { 287 pk := r[:t.pkCols] 288 rowNum, found := t.rowForPK(pk) 289 if !found { 290 // New row; do an insert. 291 t.insertRow(rowNum, r) 292 } else { 293 // Existing row; do an update. 294 for _, i := range colIndexes { 295 t.rows[rowNum][i] = r[i] 296 } 297 } 298 return nil 299 }) 300} 301 302// TODO: Replace 303 304func (d *database) Delete(table string, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error { 305 t, err := d.table(table) 306 if err != nil { 307 return err 308 } 309 310 t.mu.Lock() 311 defer t.mu.Unlock() 312 313 if all { 314 t.rows = nil 315 return nil 316 } 317 318 for _, key := range keys { 319 pk, err := t.primaryKey(key.Values) 320 if err != nil { 321 return err 322 } 323 // Not an error if the key does not exist. 324 rowNum, found := t.rowForPK(pk) 325 if found { 326 copy(t.rows[rowNum:], t.rows[rowNum+1:]) 327 t.rows = t.rows[:len(t.rows)-1] 328 } 329 } 330 331 for _, r := range keyRanges { 332 r.startKey, err = t.primaryKeyPrefix(r.start.Values) 333 if err != nil { 334 return err 335 } 336 r.endKey, err = t.primaryKeyPrefix(r.end.Values) 337 if err != nil { 338 return err 339 } 340 startRow, endRow := t.findRange(r) 341 if n := endRow - startRow; n > 0 { 342 copy(t.rows[startRow:], t.rows[endRow:]) 343 t.rows = t.rows[:len(t.rows)-n] 344 } 345 } 346 347 return nil 348} 349 350// resultIter is returned by reads and queries. 351// Use its Next method to iterate over the result rows. 352type resultIter struct { 353 // Cols is the metadata about the returned data. 354 Cols []colInfo 355 356 // rows holds the result data itself. 357 rows []resultRow 358} 359 360type resultRow struct { 361 data []interface{} 362 363 // aux is any auxiliary values evaluated for the row. 364 // When a query has an ORDER BY clause, this will contain the values for those expressions. 365 aux []interface{} 366} 367 368func (ri *resultIter) Next() ([]interface{}, bool) { 369 if len(ri.rows) == 0 { 370 return nil, false 371 } 372 res := ri.rows[0] 373 ri.rows = ri.rows[1:] 374 return res.data, true 375} 376 377func (ri *resultIter) add(src row, colIndexes []int) { 378 ri.rows = append(ri.rows, resultRow{ 379 data: src.copyData(colIndexes), 380 }) 381} 382 383// readTable executes a read option (Read, ReadAll). 384func (d *database) readTable(table string, cols []string, f func(*table, *resultIter, []int) error) (*resultIter, error) { 385 t, err := d.table(table) 386 if err != nil { 387 return nil, err 388 } 389 390 t.mu.Lock() 391 defer t.mu.Unlock() 392 393 colIndexes, err := t.colIndexes(cols) 394 if err != nil { 395 return nil, err 396 } 397 398 ri := &resultIter{} 399 for _, i := range colIndexes { 400 ri.Cols = append(ri.Cols, t.cols[i]) 401 } 402 return ri, f(t, ri, colIndexes) 403} 404 405func (d *database) Read(tbl string, cols []string, keys []*structpb.ListValue, limit int64) (*resultIter, error) { 406 return d.readTable(tbl, cols, func(t *table, ri *resultIter, colIndexes []int) error { 407 for _, key := range keys { 408 pk, err := t.primaryKey(key.Values) 409 if err != nil { 410 return err 411 } 412 // Not an error if the key does not exist. 413 rowNum, found := t.rowForPK(pk) 414 if !found { 415 continue 416 } 417 ri.add(t.rows[rowNum], colIndexes) 418 if limit > 0 && len(ri.rows) >= int(limit) { 419 break 420 } 421 } 422 return nil 423 }) 424} 425 426func (d *database) ReadAll(tbl string, cols []string, limit int64) (*resultIter, error) { 427 return d.readTable(tbl, cols, func(t *table, ri *resultIter, colIndexes []int) error { 428 for _, r := range t.rows { 429 ri.add(r, colIndexes) 430 if limit > 0 && len(ri.rows) >= int(limit) { 431 break 432 } 433 } 434 return nil 435 }) 436} 437 438type queryParams map[string]interface{} 439 440func (d *database) Query(q spansql.Query, params queryParams) (*resultIter, error) { 441 // If there's an ORDER BY clause, prepare the list of auxiliary data we need. 442 // This is provided to evalSelect to evaluate with each row. 443 var aux []spansql.Expr 444 var desc []bool 445 if len(q.Order) > 0 { 446 if len(q.Select.From) == 0 { 447 return nil, fmt.Errorf("ORDER BY doesn't work without a table") 448 } 449 450 for _, o := range q.Order { 451 aux = append(aux, o.Expr) 452 desc = append(desc, o.Desc) 453 } 454 } 455 456 ri, err := d.evalSelect(q.Select, params, aux) 457 if err != nil { 458 return nil, err 459 } 460 if len(q.Order) > 0 { 461 sort.Slice(ri.rows, func(one, two int) bool { 462 r1, r2 := ri.rows[one], ri.rows[two] 463 for i := range r1.aux { 464 cmp := compareVals(r1.aux[i], r2.aux[i]) 465 if desc[i] { 466 cmp = -cmp 467 } 468 if cmp == 0 { 469 continue 470 } 471 return cmp < 0 472 } 473 return false 474 }) 475 } 476 if q.Limit != nil { 477 lim, err := evalLimit(q.Limit, params) 478 if err != nil { 479 return nil, err 480 } 481 if n := int(lim); n < len(ri.rows) { 482 ri.rows = ri.rows[:n] 483 } 484 } 485 return ri, nil 486} 487 488func (t *table) addColumn(cd spansql.ColumnDef) *status.Status { 489 t.mu.Lock() 490 defer t.mu.Unlock() 491 492 if len(t.rows) > 0 { 493 if cd.NotNull { 494 // TODO: what happens in this case? 495 return status.Newf(codes.Unimplemented, "can't add NOT NULL columns to non-empty tables yet") 496 } 497 for i := range t.rows { 498 t.rows[i] = append(t.rows[i], nil) 499 } 500 } 501 502 t.cols = append(t.cols, colInfo{ 503 Name: cd.Name, 504 Type: cd.Type, 505 }) 506 t.colIndex[cd.Name] = len(t.cols) - 1 507 508 return nil 509} 510 511func (t *table) insertRow(rowNum int, r row) { 512 t.rows = append(t.rows, nil) 513 copy(t.rows[rowNum+1:], t.rows[rowNum:]) 514 t.rows[rowNum] = r 515} 516 517// findRange finds the rows included in the key range, 518// reporting it as a half-open interval. 519// r.startKey and r.endKey should be populated. 520func (t *table) findRange(r *keyRange) (int, int) { 521 // TODO: This is incorrect for primary keys with descending order. 522 // It might be sufficient for the caller to switch start/end in that case. 523 524 // startRow is the first row matching the range. 525 startRow := sort.Search(len(t.rows), func(i int) bool { 526 return rowCmp(r.startKey, t.rows[i][:t.pkCols]) <= 0 527 }) 528 if startRow == len(t.rows) { 529 return startRow, startRow 530 } 531 if !r.startClosed && rowCmp(r.startKey, t.rows[startRow][:t.pkCols]) == 0 { 532 startRow++ 533 } 534 535 // endRow is one more than the last row matching the range. 536 endRow := sort.Search(len(t.rows), func(i int) bool { 537 return rowCmp(r.endKey, t.rows[i][:t.pkCols]) < 0 538 }) 539 if !r.endClosed && rowCmp(r.endKey, t.rows[endRow-1][:t.pkCols]) == 0 { 540 endRow-- 541 } 542 543 return startRow, endRow 544} 545 546// colIndexes returns the indexes for the named columns. 547func (t *table) colIndexes(cols []string) ([]int, error) { 548 var is []int 549 for _, col := range cols { 550 i, ok := t.colIndex[col] 551 if !ok { 552 return nil, status.Errorf(codes.InvalidArgument, "column %s not in table", col) 553 } 554 is = append(is, i) 555 } 556 return is, nil 557} 558 559// primaryKey constructs the internal representation of a primary key. 560// The list of given values must be in 1:1 correspondence with the primary key of the table. 561func (t *table) primaryKey(values []*structpb.Value) ([]interface{}, error) { 562 if len(values) != t.pkCols { 563 return nil, status.Errorf(codes.InvalidArgument, "primary key length mismatch: got %d values, table has %d", len(values), t.pkCols) 564 } 565 return t.primaryKeyPrefix(values) 566} 567 568// primaryKeyPrefix constructs the internal representation of a primary key prefix. 569func (t *table) primaryKeyPrefix(values []*structpb.Value) ([]interface{}, error) { 570 if len(values) > t.pkCols { 571 return nil, status.Errorf(codes.InvalidArgument, "primary key length too long: got %d values, table has %d", len(values), t.pkCols) 572 } 573 574 var pk []interface{} 575 for i, value := range values { 576 v, err := valForType(value, t.cols[i].Type) 577 if err != nil { 578 return nil, err 579 } 580 pk = append(pk, v) 581 } 582 return pk, nil 583} 584 585// rowForPK returns the index of t.rows that holds the row for the given primary key, and true. 586// If the given primary key isn't found, it returns the row that should hold it, and false. 587func (t *table) rowForPK(pk []interface{}) (row int, found bool) { 588 if len(pk) != t.pkCols { 589 panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols)) 590 } 591 592 i := sort.Search(len(t.rows), func(i int) bool { 593 return rowCmp(pk, t.rows[i][:t.pkCols]) <= 0 594 }) 595 if i == len(t.rows) { 596 return i, false 597 } 598 return i, rowCmp(pk, t.rows[i][:t.pkCols]) == 0 599} 600 601// rowCmp compares two rows, returning -1/0/+1. 602// This is used for primary key matching and so doesn't support array/struct types. 603// a is permitted to be shorter than b. 604func rowCmp(a, b []interface{}) int { 605 for i := 0; i < len(a); i++ { 606 if cmp := compareVals(a[i], b[i]); cmp != 0 { 607 return cmp 608 } 609 } 610 return 0 611} 612 613func valForType(v *structpb.Value, t spansql.Type) (interface{}, error) { 614 if _, ok := v.Kind.(*structpb.Value_NullValue); ok { 615 // TODO: enforce NOT NULL constraints? 616 return nil, nil 617 } 618 619 if lv, ok := v.Kind.(*structpb.Value_ListValue); ok && t.Array { 620 et := t // element type 621 et.Array = false 622 623 // Construct the non-nil slice for the list. 624 arr := make([]interface{}, 0, len(lv.ListValue.Values)) 625 for _, v := range lv.ListValue.Values { 626 x, err := valForType(v, et) 627 if err != nil { 628 return nil, err 629 } 630 arr = append(arr, x) 631 } 632 return arr, nil 633 } 634 635 switch t.Base { 636 case spansql.Bool: 637 bv, ok := v.Kind.(*structpb.Value_BoolValue) 638 if ok { 639 return bv.BoolValue, nil 640 } 641 case spansql.Int64: 642 // The Spanner protocol encodes int64 as a decimal string. 643 sv, ok := v.Kind.(*structpb.Value_StringValue) 644 if ok { 645 x, err := strconv.ParseInt(sv.StringValue, 10, 64) 646 if err != nil { 647 return nil, fmt.Errorf("bad int64 string %q: %v", sv.StringValue, err) 648 } 649 return x, nil 650 } 651 case spansql.Float64: 652 nv, ok := v.Kind.(*structpb.Value_NumberValue) 653 if ok { 654 return nv.NumberValue, nil 655 } 656 case spansql.String: 657 sv, ok := v.Kind.(*structpb.Value_StringValue) 658 if ok { 659 return sv.StringValue, nil 660 } 661 case spansql.Bytes: 662 sv, ok := v.Kind.(*structpb.Value_StringValue) 663 if ok { 664 // The Spanner protocol encodes BYTES in base64. 665 return base64.StdEncoding.DecodeString(sv.StringValue) 666 } 667 case spansql.Date: 668 // The Spanner protocol encodes DATE in RFC 3339 date format. 669 sv, ok := v.Kind.(*structpb.Value_StringValue) 670 if ok { 671 // Store it internally as a string, but validate its value. 672 s := sv.StringValue 673 if _, err := time.Parse("2006-01-02", s); err != nil { 674 return nil, fmt.Errorf("bad DATE string %q: %v", s, err) 675 } 676 return s, nil 677 } 678 } 679 return nil, fmt.Errorf("unsupported inserting value kind %T into column of type %s", v.Kind, t.SQL()) 680} 681 682type keyRange struct { 683 start, end *structpb.ListValue 684 startClosed, endClosed bool 685 686 // These are populated during an operation 687 // when we know what table this keyRange applies to. 688 startKey, endKey []interface{} 689} 690 691func (r *keyRange) String() string { 692 var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9. 693 if r.startClosed { 694 sb.WriteString("[") 695 } else { 696 sb.WriteString("(") 697 } 698 fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey) 699 if r.endClosed { 700 sb.WriteString("]") 701 } else { 702 sb.WriteString(")") 703 } 704 return sb.String() 705} 706 707type keyRangeList []*keyRange 708