1/* 2Copyright 2020 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package spannertest 18 19import ( 20 "fmt" 21 "io" 22 "sort" 23 24 "cloud.google.com/go/spanner/spansql" 25) 26 27/* 28There's several ways to conceptualise SQL queries. The simplest, and what 29we implement here, is a series of pipelines that transform the data, whether 30pulling from a table (FROM tbl), filtering (WHERE expr), re-ordering (ORDER BY expr) 31or other transformations. 32 33The order of operations among those supported by Cloud Spanner is 34 FROM + JOIN + set ops [TODO: set ops] 35 WHERE 36 GROUP BY 37 aggregation 38 HAVING [TODO] 39 SELECT 40 DISTINCT 41 ORDER BY 42 OFFSET 43 LIMIT 44*/ 45 46// rowIter represents some iteration over rows of data. 47// It is returned by reads and queries. 48type rowIter interface { 49 // Cols returns the metadata about the returned data. 50 Cols() []colInfo 51 52 // Next returns the next row. 53 // If done, it returns (nil, io.EOF). 54 Next() (row, error) 55} 56 57// aggSentinel is a synthetic expression that refers to an aggregated value. 58// It is transient only; it is never stored and only used during evaluation. 59type aggSentinel struct { 60 spansql.Expr 61 Type spansql.Type 62 AggIndex int // Index+1 of SELECT list. 63} 64 65// nullIter is a rowIter that returns one empty row only. 66// This is used for queries without a table. 67type nullIter struct { 68 done bool 69} 70 71func (ni *nullIter) Cols() []colInfo { return nil } 72func (ni *nullIter) Next() (row, error) { 73 if ni.done { 74 return nil, io.EOF 75 } 76 ni.done = true 77 return nil, nil 78} 79 80// tableIter is a rowIter that walks a table. 81// It assumes the table is locked for the duration. 82type tableIter struct { 83 t *table 84 rowIndex int // index of next row to return 85 86 alias spansql.ID // if non-empty, "AS <alias>" 87} 88 89func (ti *tableIter) Cols() []colInfo { 90 // Build colInfo in the original column order. 91 cis := make([]colInfo, len(ti.t.cols)) 92 for _, ci := range ti.t.cols { 93 if ti.alias != "" { 94 ci.Alias = spansql.PathExp{ti.alias, ci.Name} 95 } 96 cis[ti.t.origIndex[ci.Name]] = ci 97 } 98 return cis 99} 100 101func (ti *tableIter) Next() (row, error) { 102 if ti.rowIndex >= len(ti.t.rows) { 103 return nil, io.EOF 104 } 105 r := ti.t.rows[ti.rowIndex] 106 ti.rowIndex++ 107 108 // Build output row in the original column order. 109 res := make(row, len(r)) 110 for i, ci := range ti.t.cols { 111 res[ti.t.origIndex[ci.Name]] = r[i] 112 } 113 114 return res, nil 115} 116 117// rawIter is a rowIter with fixed data. 118type rawIter struct { 119 // cols is the metadata about the returned data. 120 cols []colInfo 121 122 // rows holds the result data itself. 123 rows []row 124} 125 126func (raw *rawIter) Cols() []colInfo { return raw.cols } 127func (raw *rawIter) Next() (row, error) { 128 if len(raw.rows) == 0 { 129 return nil, io.EOF 130 } 131 res := raw.rows[0] 132 raw.rows = raw.rows[1:] 133 return res, nil 134} 135 136func (raw *rawIter) add(src row, colIndexes []int) { 137 raw.rows = append(raw.rows, src.copyData(colIndexes)) 138} 139 140// clone makes a shallow copy. 141func (raw *rawIter) clone() *rawIter { 142 return &rawIter{cols: raw.cols, rows: raw.rows} 143} 144 145func toRawIter(ri rowIter) (*rawIter, error) { 146 if raw, ok := ri.(*rawIter); ok { 147 return raw, nil 148 } 149 raw := &rawIter{cols: ri.Cols()} 150 for { 151 row, err := ri.Next() 152 if err == io.EOF { 153 break 154 } else if err != nil { 155 return nil, err 156 } 157 raw.rows = append(raw.rows, row.copyAllData()) 158 } 159 return raw, nil 160} 161 162// whereIter applies a WHERE clause. 163type whereIter struct { 164 ri rowIter 165 ec evalContext 166 where spansql.BoolExpr 167} 168 169func (wi whereIter) Cols() []colInfo { return wi.ri.Cols() } 170func (wi whereIter) Next() (row, error) { 171 for { 172 row, err := wi.ri.Next() 173 if err != nil { 174 return nil, err 175 } 176 wi.ec.row = row 177 178 b, err := wi.ec.evalBoolExpr(wi.where) 179 if err != nil { 180 return nil, err 181 } 182 if b != nil && *b { 183 return row, nil 184 } 185 } 186} 187 188// selIter applies a SELECT list. 189type selIter struct { 190 ri rowIter 191 ec evalContext 192 cis []colInfo 193 list []spansql.Expr 194 195 distinct bool // whether this is a SELECT DISTINCT 196 seen []row 197} 198 199func (si *selIter) Cols() []colInfo { return si.cis } 200func (si *selIter) Next() (row, error) { 201 for { 202 r, err := si.next() 203 if err != nil { 204 return nil, err 205 } 206 if si.distinct && !si.keep(r) { 207 continue 208 } 209 return r, nil 210 } 211} 212 213// next retrieves the next row for the SELECT and evaluates its expression list. 214func (si *selIter) next() (row, error) { 215 r, err := si.ri.Next() 216 if err != nil { 217 return nil, err 218 } 219 si.ec.row = r 220 221 var out row 222 for _, e := range si.list { 223 if e == spansql.Star { 224 out = append(out, r...) 225 } else { 226 v, err := si.ec.evalExpr(e) 227 if err != nil { 228 return nil, err 229 } 230 out = append(out, v) 231 } 232 } 233 return out, nil 234} 235 236func (si *selIter) keep(r row) bool { 237 // This is hilariously inefficient; O(N^2) in the number of returned rows. 238 // Some sort of hashing could be done to deduplicate instead. 239 // This also breaks on array/struct types. 240 for _, prev := range si.seen { 241 if rowEqual(prev, r) { 242 return false 243 } 244 } 245 si.seen = append(si.seen, r) 246 return true 247} 248 249// offsetIter applies an OFFSET clause. 250type offsetIter struct { 251 ri rowIter 252 skip int64 253} 254 255func (oi *offsetIter) Cols() []colInfo { return oi.ri.Cols() } 256func (oi *offsetIter) Next() (row, error) { 257 for oi.skip > 0 { 258 _, err := oi.ri.Next() 259 if err != nil { 260 return nil, err 261 } 262 oi.skip-- 263 } 264 row, err := oi.ri.Next() 265 if err != nil { 266 return nil, err 267 } 268 return row, nil 269} 270 271// limitIter applies a LIMIT clause. 272type limitIter struct { 273 ri rowIter 274 rem int64 275} 276 277func (li *limitIter) Cols() []colInfo { return li.ri.Cols() } 278func (li *limitIter) Next() (row, error) { 279 if li.rem <= 0 { 280 return nil, io.EOF 281 } 282 row, err := li.ri.Next() 283 if err != nil { 284 return nil, err 285 } 286 li.rem-- 287 return row, nil 288} 289 290type queryParam struct { 291 Value interface{} // internal representation 292 Type spansql.Type 293} 294 295type queryParams map[string]queryParam // TODO: change key to spansql.Param? 296 297type queryContext struct { 298 params queryParams 299 300 tables []*table // sorted by name 301 tableIndex map[spansql.ID]*table 302 locks int 303} 304 305func (qc *queryContext) Lock() { 306 // Take locks in name order. 307 for _, t := range qc.tables { 308 t.mu.Lock() 309 qc.locks++ 310 } 311} 312 313func (qc *queryContext) Unlock() { 314 for _, t := range qc.tables { 315 t.mu.Unlock() 316 qc.locks-- 317 } 318} 319 320func (d *database) Query(q spansql.Query, params queryParams) (ri rowIter, err error) { 321 // Figure out the context of the query and take any required locks. 322 qc, err := d.queryContext(q, params) 323 if err != nil { 324 return nil, err 325 } 326 qc.Lock() 327 // On the way out, if there were locks taken, flatten the output 328 // and release the locks. 329 if qc.locks > 0 { 330 defer func() { 331 if err == nil { 332 ri, err = toRawIter(ri) 333 } 334 qc.Unlock() 335 }() 336 } 337 338 // Prepare auxiliary expressions to evaluate for ORDER BY. 339 var aux []spansql.Expr 340 var desc []bool 341 for _, o := range q.Order { 342 aux = append(aux, o.Expr) 343 desc = append(desc, o.Desc) 344 } 345 346 si, err := d.evalSelect(q.Select, qc) 347 if err != nil { 348 return nil, err 349 } 350 ri = si 351 352 // Apply ORDER BY. 353 if len(q.Order) > 0 { 354 // Evaluate the selIter completely, and sort the rows by the auxiliary expressions. 355 rows, keys, err := evalSelectOrder(si, aux) 356 if err != nil { 357 return nil, err 358 } 359 sort.Sort(externalRowSorter{rows: rows, keys: keys, desc: desc}) 360 ri = &rawIter{cols: si.cis, rows: rows} 361 } 362 363 // Apply LIMIT, OFFSET. 364 if q.Limit != nil { 365 if q.Offset != nil { 366 off, err := evalLiteralOrParam(q.Offset, params) 367 if err != nil { 368 return nil, err 369 } 370 ri = &offsetIter{ri: ri, skip: off} 371 } 372 373 lim, err := evalLiteralOrParam(q.Limit, params) 374 if err != nil { 375 return nil, err 376 } 377 ri = &limitIter{ri: ri, rem: lim} 378 } 379 380 return ri, nil 381} 382 383func (d *database) queryContext(q spansql.Query, params queryParams) (*queryContext, error) { 384 qc := &queryContext{ 385 params: params, 386 } 387 388 // Look for any mentioned tables and add them to qc.tableIndex. 389 addTable := func(name spansql.ID) error { 390 if _, ok := qc.tableIndex[name]; ok { 391 return nil // Already found this table. 392 } 393 t, err := d.table(name) 394 if err != nil { 395 return err 396 } 397 if qc.tableIndex == nil { 398 qc.tableIndex = make(map[spansql.ID]*table) 399 } 400 qc.tableIndex[name] = t 401 return nil 402 } 403 var findTables func(sf spansql.SelectFrom) error 404 findTables = func(sf spansql.SelectFrom) error { 405 switch sf := sf.(type) { 406 default: 407 return fmt.Errorf("can't prepare query context for SelectFrom of type %T", sf) 408 case spansql.SelectFromTable: 409 return addTable(sf.Table) 410 case spansql.SelectFromJoin: 411 if err := findTables(sf.LHS); err != nil { 412 return err 413 } 414 return findTables(sf.RHS) 415 case spansql.SelectFromUnnest: 416 // TODO: if array paths get supported, this will need more work. 417 return nil 418 } 419 } 420 for _, sf := range q.Select.From { 421 if err := findTables(sf); err != nil { 422 return nil, err 423 } 424 } 425 426 // Build qc.tables in name order so we can take locks in a well-defined order. 427 var names []spansql.ID 428 for name := range qc.tableIndex { 429 names = append(names, name) 430 } 431 sort.Slice(names, func(i, j int) bool { return names[i] < names[j] }) 432 for _, name := range names { 433 qc.tables = append(qc.tables, qc.tableIndex[name]) 434 } 435 436 return qc, nil 437} 438 439func (d *database) evalSelect(sel spansql.Select, qc *queryContext) (si *selIter, evalErr error) { 440 var ri rowIter = &nullIter{} 441 ec := evalContext{ 442 params: qc.params, 443 } 444 445 // First stage is to identify the data source. 446 // If there's a FROM then that names a table to use. 447 if len(sel.From) > 1 { 448 return nil, fmt.Errorf("selecting with more than one FROM clause not yet supported") 449 } 450 if len(sel.From) == 1 { 451 var err error 452 ec, ri, err = d.evalSelectFrom(qc, ec, sel.From[0]) 453 if err != nil { 454 return nil, err 455 } 456 } 457 458 // Apply WHERE. 459 if sel.Where != nil { 460 ri = whereIter{ 461 ri: ri, 462 ec: ec, 463 where: sel.Where, 464 } 465 } 466 467 // Load aliases visible to any future iterators, 468 // including GROUP BY and ORDER BY. These are not visible to the WHERE clause. 469 ec.aliases = make(map[spansql.ID]spansql.Expr) 470 for i, alias := range sel.ListAliases { 471 ec.aliases[alias] = sel.List[i] 472 } 473 // TODO: Add aliases for "1", "2", etc. 474 475 // Apply GROUP BY. 476 // This only reorders rows to group rows together; 477 // aggregation happens next. 478 var rowGroups [][2]int // Sequence of half-open intervals of row numbers. 479 if len(sel.GroupBy) > 0 { 480 raw, err := toRawIter(ri) 481 if err != nil { 482 return nil, err 483 } 484 keys := make([][]interface{}, 0, len(raw.rows)) 485 for _, row := range raw.rows { 486 // Evaluate sort key for this row. 487 ec.row = row 488 key, err := ec.evalExprList(sel.GroupBy) 489 if err != nil { 490 return nil, err 491 } 492 keys = append(keys, key) 493 } 494 495 // Reorder rows base on the evaluated keys. 496 ers := externalRowSorter{rows: raw.rows, keys: keys} 497 sort.Sort(ers) 498 raw.rows = ers.rows 499 500 // Record groups as a sequence of row intervals. 501 // Each group is a run of the same keys. 502 start := 0 503 for i := 1; i < len(keys); i++ { 504 if compareValLists(keys[i-1], keys[i], nil) == 0 { 505 continue 506 } 507 rowGroups = append(rowGroups, [2]int{start, i}) 508 start = i 509 } 510 if len(keys) > 0 { 511 rowGroups = append(rowGroups, [2]int{start, len(keys)}) 512 } 513 514 // Clear aliases, since they aren't visible elsewhere. 515 ec.aliases = nil 516 517 ri = raw 518 } 519 520 // Handle aggregation. 521 var aggI []int 522 for i, e := range sel.List { 523 // Supported aggregate funcs have exactly one arg. 524 f, ok := e.(spansql.Func) 525 if !ok || len(f.Args) != 1 { 526 continue 527 } 528 _, ok = aggregateFuncs[f.Name] 529 if !ok { 530 continue 531 } 532 aggI = append(aggI, i) 533 } 534 if len(aggI) > 0 { 535 raw, err := toRawIter(ri) 536 if err != nil { 537 return nil, err 538 } 539 if len(sel.GroupBy) == 0 { 540 // No grouping, so aggregation applies to the entire table (e.g. COUNT(*)). 541 // This may result in a [0,0) entry for empty inputs. 542 rowGroups = [][2]int{{0, len(raw.rows)}} 543 } 544 545 // Prepare output. 546 rawOut := &rawIter{ 547 // Same as input columns, but also the aggregate value. 548 // Add the colInfo for the aggregate at the end 549 // so we know the type. 550 // Make a copy for safety. 551 cols: append([]colInfo(nil), raw.cols...), 552 } 553 554 aggType := make([]*spansql.Type, len(aggI)) 555 for _, rg := range rowGroups { 556 var outRow row 557 // Output for the row group is the first row of the group (arbitrary, 558 // but it should be representative), and the aggregate value. 559 // TODO: Should this exclude the aggregated expressions so they can't be selected? 560 // If the row group is empty then only the aggregation value is used; 561 // this covers things like COUNT(*) with no matching rows. 562 if rg[0] < len(raw.rows) { 563 repRow := raw.rows[rg[0]] 564 for i := range repRow { 565 outRow = append(outRow, repRow.copyDataElem(i)) 566 } 567 } else { 568 // Fill with NULLs to keep the rows and colInfo aligned. 569 for i := 0; i < len(rawOut.cols); i++ { 570 outRow = append(outRow, nil) 571 } 572 } 573 574 for j, aggI := range aggI { 575 fexpr := sel.List[aggI].(spansql.Func) 576 fn := aggregateFuncs[fexpr.Name] 577 starArg := fexpr.Args[0] == spansql.Star 578 if starArg && !fn.AcceptStar { 579 return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name) 580 } 581 var argType spansql.Type 582 if !starArg { 583 ci, err := ec.colInfo(fexpr.Args[0]) 584 if err != nil { 585 return nil, fmt.Errorf("evaluating aggregate function %s arg type: %v", fexpr.Name, err) 586 } 587 argType = ci.Type 588 } 589 590 // Compute aggregate value across this group. 591 var values []interface{} 592 for i := rg[0]; i < rg[1]; i++ { 593 ec.row = raw.rows[i] 594 if starArg { 595 // A non-NULL placeholder is sufficient for aggregation. 596 values = append(values, 1) 597 } else { 598 x, err := ec.evalExpr(fexpr.Args[0]) 599 if err != nil { 600 return nil, err 601 } 602 values = append(values, x) 603 } 604 } 605 x, typ, err := fn.Eval(values, argType) 606 if err != nil { 607 return nil, err 608 } 609 aggType[j] = &typ 610 611 outRow = append(outRow, x) 612 } 613 rawOut.rows = append(rawOut.rows, outRow) 614 } 615 616 for j, aggI := range aggI { 617 fexpr := sel.List[aggI].(spansql.Func) 618 if aggType[j] == nil { 619 // Fallback; there might not be any groups. 620 // TODO: Should this be in aggregateFunc? 621 aggType[j] = &int64Type 622 } 623 rawOut.cols = append(rawOut.cols, colInfo{ 624 Name: spansql.ID(fexpr.SQL()), // TODO: this is a bit hokey, but it is output only 625 Type: *aggType[j], 626 AggIndex: aggI + 1, 627 }) 628 sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value. 629 Type: *aggType[j], 630 AggIndex: aggI + 1, 631 } 632 } 633 ri = rawOut 634 ec.cols = rawOut.cols 635 } 636 637 // TODO: Support table sampling. 638 639 // Apply SELECT list. 640 var colInfos []colInfo 641 for i, e := range sel.List { 642 if e == spansql.Star { 643 colInfos = append(colInfos, ec.cols...) 644 } else { 645 ci, err := ec.colInfo(e) 646 if err != nil { 647 return nil, err 648 } 649 if len(sel.ListAliases) > 0 { 650 alias := sel.ListAliases[i] 651 if alias != "" { 652 ci.Name = alias 653 } 654 } 655 // TODO: deal with ci.Name == ""? 656 colInfos = append(colInfos, ci) 657 } 658 } 659 660 return &selIter{ 661 ri: ri, 662 ec: ec, 663 cis: colInfos, 664 list: sel.List, 665 666 distinct: sel.Distinct, // Apply DISTINCT. 667 }, nil 668} 669 670func (d *database) evalSelectFrom(qc *queryContext, ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, error) { 671 switch sf := sf.(type) { 672 default: 673 return ec, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf) 674 case spansql.SelectFromTable: 675 t, ok := qc.tableIndex[sf.Table] 676 if !ok { 677 // This shouldn't be possible; the queryContext should have discovered missing tables already. 678 return ec, nil, fmt.Errorf("unknown table %q", sf.Table) 679 } 680 ti := &tableIter{t: t} 681 if sf.Alias != "" { 682 ti.alias = sf.Alias 683 } else { 684 // There is an implicit alias using the table name. 685 // https://cloud.google.com/spanner/docs/query-syntax#implicit_aliases 686 ti.alias = sf.Table 687 } 688 ec.cols = ti.Cols() 689 return ec, ti, nil 690 case spansql.SelectFromJoin: 691 // TODO: Avoid the toRawIter calls here by doing the RHS recursive evalSelectFrom in joinIter.Next on demand. 692 693 lhsEC, lhs, err := d.evalSelectFrom(qc, ec, sf.LHS) 694 if err != nil { 695 return ec, nil, err 696 } 697 lhsRaw, err := toRawIter(lhs) 698 if err != nil { 699 return ec, nil, err 700 } 701 702 rhsEC, rhs, err := d.evalSelectFrom(qc, ec, sf.RHS) 703 if err != nil { 704 return ec, nil, err 705 } 706 rhsRaw, err := toRawIter(rhs) 707 if err != nil { 708 return ec, nil, err 709 } 710 711 ji, ec, err := newJoinIter(lhsRaw, rhsRaw, lhsEC, rhsEC, sf) 712 if err != nil { 713 return ec, nil, err 714 } 715 return ec, ji, nil 716 case spansql.SelectFromUnnest: 717 // TODO: Do all relevant types flow through here? Path expressions might be tricky here. 718 col, err := ec.colInfo(sf.Expr) 719 if err != nil { 720 return ec, nil, fmt.Errorf("evaluating type of UNNEST arg: %v", err) 721 } 722 if !col.Type.Array { 723 return ec, nil, fmt.Errorf("type of UNNEST arg is non-array %s", col.Type.SQL()) 724 } 725 // The output of this UNNEST is the non-array version. 726 col.Name = sf.Alias // may be empty 727 col.Type.Array = false 728 729 // Evaluate the expression, and yield a virtual table with one column. 730 e, err := ec.evalExpr(sf.Expr) 731 if err != nil { 732 return ec, nil, fmt.Errorf("evaluating UNNEST arg: %v", err) 733 } 734 arr, ok := e.([]interface{}) 735 if !ok { 736 return ec, nil, fmt.Errorf("evaluating UNNEST arg gave %t, want array", e) 737 } 738 var rows []row 739 for _, v := range arr { 740 rows = append(rows, row{v}) 741 } 742 743 ri := &rawIter{ 744 cols: []colInfo{col}, 745 rows: rows, 746 } 747 ec.cols = ri.cols 748 return ec, ri, nil 749 } 750} 751 752func newJoinIter(lhs, rhs *rawIter, lhsEC, rhsEC evalContext, sfj spansql.SelectFromJoin) (*joinIter, evalContext, error) { 753 if sfj.On != nil && len(sfj.Using) > 0 { 754 return nil, evalContext{}, fmt.Errorf("JOIN may not have both ON and USING clauses") 755 } 756 if sfj.On == nil && len(sfj.Using) == 0 && sfj.Type != spansql.CrossJoin { 757 // TODO: This isn't correct for joining against a non-table. 758 return nil, evalContext{}, fmt.Errorf("non-CROSS JOIN must have ON or USING clause") 759 } 760 761 // Start with the context from the LHS (aliases and params should be the same on both sides). 762 ji := &joinIter{ 763 jt: sfj.Type, 764 ec: lhsEC, 765 766 primary: lhs, 767 secondaryOrig: rhs, 768 769 primaryOffset: 0, 770 secondaryOffset: len(lhsEC.cols), 771 } 772 switch ji.jt { 773 case spansql.LeftJoin: 774 ji.nullPad = true 775 case spansql.RightJoin: 776 ji.nullPad = true 777 // Primary is RHS. 778 ji.ec = rhsEC 779 ji.primary, ji.secondaryOrig = rhs, lhs 780 ji.primaryOffset, ji.secondaryOffset = len(rhsEC.cols), 0 781 case spansql.FullJoin: 782 // FULL JOIN is implemented as a LEFT JOIN with tracking for which rows of the RHS 783 // have been used. Then, at the end of the iteration, the unused RHS rows are emitted. 784 ji.nullPad = true 785 ji.used = make([]bool, 0, 10) // arbitrary preallocation 786 } 787 ji.ec.cols, ji.ec.row = nil, nil 788 789 // Construct a merged evalContext, and prepare the join condition evaluation. 790 // TODO: Remove ambiguous names here? Or catch them when evaluated? 791 // TODO: aliases might need work? 792 if len(sfj.Using) == 0 { 793 ji.prepNonUsing(sfj.On, lhsEC, rhsEC) 794 } else { 795 if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC, ji.jt == spansql.RightJoin); err != nil { 796 return nil, evalContext{}, err 797 } 798 } 799 800 return ji, ji.ec, nil 801} 802 803// prepNonUsing configures the joinIter to evaluate with an ON clause or no join clause. 804// The arg is nil in the latter case. 805func (ji *joinIter) prepNonUsing(on spansql.BoolExpr, lhsEC, rhsEC evalContext) { 806 // Having ON or no clause results in the full set of columns from both sides. 807 // Force a copy. 808 ji.ec.cols = append(ji.ec.cols, lhsEC.cols...) 809 ji.ec.cols = append(ji.ec.cols, rhsEC.cols...) 810 ji.ec.row = make(row, len(ji.ec.cols)) 811 812 ji.cond = func(primary, secondary row) (bool, error) { 813 copy(ji.ec.row[ji.primaryOffset:], primary) 814 copy(ji.ec.row[ji.secondaryOffset:], secondary) 815 if on == nil { 816 // No condition; all rows match. 817 return true, nil 818 } 819 b, err := ji.ec.evalBoolExpr(on) 820 if err != nil { 821 return false, err 822 } 823 return b != nil && *b, nil 824 } 825 ji.zero = func(primary, secondary row) { 826 for i := range ji.ec.row { 827 ji.ec.row[i] = nil 828 } 829 copy(ji.ec.row[ji.primaryOffset:], primary) 830 copy(ji.ec.row[ji.secondaryOffset:], secondary) 831 } 832} 833 834func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext, flipped bool) error { 835 // Having a USING clause results in the set of named columns once, 836 // followed by the unnamed columns from both sides. 837 838 // lhsUsing is the column indexes in the LHS that the USING clause references. 839 // rhsUsing is similar. 840 // lhsNotUsing/rhsNotUsing are the complement. 841 var lhsUsing, rhsUsing []int 842 var lhsNotUsing, rhsNotUsing []int 843 // lhsUsed, rhsUsed are the set of column indexes in lhsUsing/rhsUsing. 844 lhsUsed, rhsUsed := make(map[int]bool), make(map[int]bool) 845 for _, id := range using { 846 lhsi, err := lhsEC.resolveColumnIndex(id) 847 if err != nil { 848 return err 849 } 850 lhsUsing = append(lhsUsing, lhsi) 851 lhsUsed[lhsi] = true 852 853 rhsi, err := rhsEC.resolveColumnIndex(id) 854 if err != nil { 855 return err 856 } 857 rhsUsing = append(rhsUsing, rhsi) 858 rhsUsed[rhsi] = true 859 860 // TODO: Should this hide or merge column aliases? 861 ji.ec.cols = append(ji.ec.cols, lhsEC.cols[lhsi]) 862 } 863 for i, col := range lhsEC.cols { 864 if !lhsUsed[i] { 865 ji.ec.cols = append(ji.ec.cols, col) 866 lhsNotUsing = append(lhsNotUsing, i) 867 } 868 } 869 for i, col := range rhsEC.cols { 870 if !rhsUsed[i] { 871 ji.ec.cols = append(ji.ec.cols, col) 872 rhsNotUsing = append(rhsNotUsing, i) 873 } 874 } 875 ji.ec.row = make(row, len(ji.ec.cols)) 876 877 primaryUsing, secondaryUsing := lhsUsing, rhsUsing 878 if flipped { 879 primaryUsing, secondaryUsing = secondaryUsing, primaryUsing 880 } 881 882 orNil := func(r row, i int) interface{} { 883 if r == nil { 884 return nil 885 } 886 return r[i] 887 } 888 // populate writes the data to ji.ec.row in the correct positions. 889 populate := func(primary, secondary row) { // either may be nil 890 j := 0 891 if primary != nil { 892 for _, pi := range primaryUsing { 893 ji.ec.row[j] = primary[pi] 894 j++ 895 } 896 } else { 897 for _, si := range secondaryUsing { 898 ji.ec.row[j] = secondary[si] 899 j++ 900 } 901 } 902 lhs, rhs := primary, secondary 903 if flipped { 904 rhs, lhs = lhs, rhs 905 } 906 for _, i := range lhsNotUsing { 907 ji.ec.row[j] = orNil(lhs, i) 908 j++ 909 } 910 for _, i := range rhsNotUsing { 911 ji.ec.row[j] = orNil(rhs, i) 912 j++ 913 } 914 } 915 ji.cond = func(primary, secondary row) (bool, error) { 916 for i, pi := range primaryUsing { 917 si := secondaryUsing[i] 918 if compareVals(primary[pi], secondary[si]) != 0 { 919 return false, nil 920 } 921 } 922 populate(primary, secondary) 923 return true, nil 924 } 925 ji.zero = func(primary, secondary row) { 926 populate(primary, secondary) 927 } 928 return nil 929} 930 931type joinIter struct { 932 jt spansql.JoinType 933 ec evalContext // combined context 934 935 // The "primary" is scanned (consumed), but the secondary is cloned for each primary row. 936 // Most join types have primary==LHS; a RIGHT JOIN is the exception. 937 primary, secondaryOrig *rawIter 938 939 // The offsets into ec.row that the primary/secondary rows should appear 940 // in the final output. Not used when there's a USING clause. 941 primaryOffset, secondaryOffset int 942 // nullPad is whether primary rows without matching secondary rows 943 // should be yielded with null padding (e.g. OUTER JOINs). 944 nullPad bool 945 946 primaryRow row // current row from primary, or nil if it is time to advance 947 secondary *rawIter // current clone of secondary 948 secondaryRead int // number of rows already read from secondary 949 any bool // true if any secondary rows have matched primaryRow 950 951 // cond reports whether the primary and secondary rows "join" (e.g. the ON clause is true). 952 // It populates ec.row with the output. 953 cond func(primary, secondary row) (bool, error) 954 // zero populates ec.row with the primary or secondary row data (either of which may be nil), 955 // and sets the remainder to NULL. 956 // This is used when nullPad is true and a primary or secondary row doesn't match. 957 zero func(primary, secondary row) 958 959 // For FULL JOIN, this tracks the secondary rows that have been used. 960 // It is non-nil when being used. 961 used []bool 962 zeroUnused bool // set when emitting unused secondary rows 963 unusedIndex int // next index of used to check 964} 965 966func (ji *joinIter) Cols() []colInfo { return ji.ec.cols } 967 968func (ji *joinIter) nextPrimary() error { 969 var err error 970 ji.primaryRow, err = ji.primary.Next() 971 if err != nil { 972 return err 973 } 974 ji.secondary = ji.secondaryOrig.clone() 975 ji.secondaryRead = 0 976 ji.any = false 977 return nil 978} 979 980func (ji *joinIter) Next() (row, error) { 981 if ji.primaryRow == nil && !ji.zeroUnused { 982 err := ji.nextPrimary() 983 if err == io.EOF && ji.used != nil { 984 // Drop down to emitting unused secondary rows. 985 ji.zeroUnused = true 986 ji.secondary = nil 987 goto scanJiUsed 988 } 989 if err != nil { 990 return nil, err 991 } 992 } 993scanJiUsed: 994 if ji.zeroUnused { 995 if ji.secondary == nil { 996 ji.secondary = ji.secondaryOrig.clone() 997 ji.secondaryRead = 0 998 } 999 for ji.unusedIndex < len(ji.used) && ji.used[ji.unusedIndex] { 1000 ji.unusedIndex++ 1001 } 1002 if ji.unusedIndex >= len(ji.used) || ji.secondaryRead >= len(ji.used) { 1003 // Truly finished. 1004 return nil, io.EOF 1005 } 1006 var secondaryRow row 1007 for ji.secondaryRead <= ji.unusedIndex { 1008 var err error 1009 secondaryRow, err = ji.secondary.Next() 1010 if err != nil { 1011 return nil, err 1012 } 1013 ji.secondaryRead++ 1014 } 1015 ji.zero(nil, secondaryRow) 1016 return ji.ec.row, nil 1017 } 1018 1019 for { 1020 secondaryRow, err := ji.secondary.Next() 1021 if err == io.EOF { 1022 // Finished the current primary row. 1023 1024 if !ji.any && ji.nullPad { 1025 ji.zero(ji.primaryRow, nil) 1026 ji.primaryRow = nil 1027 return ji.ec.row, nil 1028 } 1029 1030 // Advance to next one. 1031 err := ji.nextPrimary() 1032 if err == io.EOF && ji.used != nil { 1033 ji.zeroUnused = true 1034 ji.secondary = nil 1035 goto scanJiUsed 1036 } 1037 if err != nil { 1038 return nil, err 1039 } 1040 continue 1041 } 1042 if err != nil { 1043 return nil, err 1044 } 1045 ji.secondaryRead++ 1046 if ji.used != nil { 1047 for len(ji.used) < ji.secondaryRead { 1048 ji.used = append(ji.used, false) 1049 } 1050 } 1051 1052 // We have a pair of rows to consider. 1053 match, err := ji.cond(ji.primaryRow, secondaryRow) 1054 if err != nil { 1055 return nil, err 1056 } 1057 if !match { 1058 continue 1059 } 1060 ji.any = true 1061 if ji.used != nil { 1062 // Make a note that we used this secondary row. 1063 ji.used[ji.secondaryRead-1] = true 1064 } 1065 return ji.ec.row, nil 1066 } 1067} 1068 1069func evalSelectOrder(si *selIter, aux []spansql.Expr) (rows []row, keys [][]interface{}, err error) { 1070 // This is like toRawIter except it also evaluates the auxiliary expressions for ORDER BY. 1071 for { 1072 r, err := si.Next() 1073 if err == io.EOF { 1074 break 1075 } else if err != nil { 1076 return nil, nil, err 1077 } 1078 key, err := si.ec.evalExprList(aux) 1079 if err != nil { 1080 return nil, nil, err 1081 } 1082 1083 rows = append(rows, r.copyAllData()) 1084 keys = append(keys, key) 1085 } 1086 return 1087} 1088 1089// externalRowSorter implements sort.Interface for a slice of rows 1090// with an external sort key. 1091type externalRowSorter struct { 1092 rows []row 1093 keys [][]interface{} 1094 desc []bool // may be nil 1095} 1096 1097func (ers externalRowSorter) Len() int { return len(ers.rows) } 1098func (ers externalRowSorter) Less(i, j int) bool { 1099 return compareValLists(ers.keys[i], ers.keys[j], ers.desc) < 0 1100} 1101func (ers externalRowSorter) Swap(i, j int) { 1102 ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i] 1103 ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i] 1104} 1105