/*
Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package spannertest

import (
	"fmt"
	"io"
	"sort"

	"cloud.google.com/go/spanner/spansql"
)

/*
There are several ways to conceptualise SQL queries. The simplest, and what
we implement here, is a series of pipelines that transform the data, whether
pulling from a table (FROM tbl), filtering (WHERE expr), re-ordering (ORDER BY expr)
or other transformations. Most stages are a rowIter wrapping the rowIter of the
previous stage; stages that need to see every row (GROUP BY, aggregation,
ORDER BY) first materialize their input into a rawIter.

The order of operations among those supported by Cloud Spanner is
	FROM + JOIN + set ops [TODO: set ops]
	WHERE
	GROUP BY
	aggregation
	HAVING [TODO]
	SELECT
	DISTINCT
	ORDER BY
	OFFSET
	LIMIT
*/

// rowIter represents some iteration over rows of data.
// It is returned by reads and queries.
type rowIter interface {
	// Cols returns the metadata about the returned data.
	Cols() []colInfo

	// Next returns the next row.
	// If done, it returns (nil, io.EOF).
	// Some implementations reuse the returned row's storage across calls,
	// so callers that keep rows should copy them (as toRawIter does).
	Next() (row, error)
}

// aggSentinel is a synthetic expression that refers to an aggregated value.
// It is transient only; it is never stored and only used during evaluation.
// It is substituted into the SELECT list in place of the aggregate function
// call once the aggregate has been computed. The embedded spansql.Expr is left
// nil where aggSentinel is constructed; embedding it lets an aggSentinel be
// stored in a []spansql.Expr.
type aggSentinel struct {
	spansql.Expr
	Type     spansql.Type
	AggIndex int // Index+1 of SELECT list; matches colInfo.AggIndex of the aggregate column.
}

// nullIter is a rowIter that returns one empty row only.
// This is used for queries without a table (no FROM clause), so that a SELECT
// list of constant expressions is evaluated exactly once.
67type nullIter struct { 68 done bool 69} 70 71func (ni *nullIter) Cols() []colInfo { return nil } 72func (ni *nullIter) Next() (row, error) { 73 if ni.done { 74 return nil, io.EOF 75 } 76 ni.done = true 77 return nil, nil 78} 79 80// tableIter is a rowIter that walks a table. 81// It assumes the table is locked for the duration. 82type tableIter struct { 83 t *table 84 rowIndex int // index of next row to return 85 86 alias spansql.ID // if non-empty, "AS <alias>" 87} 88 89func (ti *tableIter) Cols() []colInfo { 90 // Build colInfo in the original column order. 91 cis := make([]colInfo, len(ti.t.cols)) 92 for _, ci := range ti.t.cols { 93 if ti.alias != "" { 94 ci.Alias = spansql.PathExp{ti.alias, ci.Name} 95 } 96 cis[ti.t.origIndex[ci.Name]] = ci 97 } 98 return cis 99} 100 101func (ti *tableIter) Next() (row, error) { 102 if ti.rowIndex >= len(ti.t.rows) { 103 return nil, io.EOF 104 } 105 r := ti.t.rows[ti.rowIndex] 106 ti.rowIndex++ 107 108 // Build output row in the original column order. 109 res := make(row, len(r)) 110 for i, ci := range ti.t.cols { 111 res[ti.t.origIndex[ci.Name]] = r[i] 112 } 113 114 return res, nil 115} 116 117// rawIter is a rowIter with fixed data. 118type rawIter struct { 119 // cols is the metadata about the returned data. 120 cols []colInfo 121 122 // rows holds the result data itself. 123 rows []row 124} 125 126func (raw *rawIter) Cols() []colInfo { return raw.cols } 127func (raw *rawIter) Next() (row, error) { 128 if len(raw.rows) == 0 { 129 return nil, io.EOF 130 } 131 res := raw.rows[0] 132 raw.rows = raw.rows[1:] 133 return res, nil 134} 135 136func (raw *rawIter) add(src row, colIndexes []int) { 137 raw.rows = append(raw.rows, src.copyData(colIndexes)) 138} 139 140// clone makes a shallow copy. 
141func (raw *rawIter) clone() *rawIter { 142 return &rawIter{cols: raw.cols, rows: raw.rows} 143} 144 145func toRawIter(ri rowIter) (*rawIter, error) { 146 if raw, ok := ri.(*rawIter); ok { 147 return raw, nil 148 } 149 raw := &rawIter{cols: ri.Cols()} 150 for { 151 row, err := ri.Next() 152 if err == io.EOF { 153 break 154 } else if err != nil { 155 return nil, err 156 } 157 raw.rows = append(raw.rows, row.copyAllData()) 158 } 159 return raw, nil 160} 161 162// whereIter applies a WHERE clause. 163type whereIter struct { 164 ri rowIter 165 ec evalContext 166 where spansql.BoolExpr 167} 168 169func (wi whereIter) Cols() []colInfo { return wi.ri.Cols() } 170func (wi whereIter) Next() (row, error) { 171 for { 172 row, err := wi.ri.Next() 173 if err != nil { 174 return nil, err 175 } 176 wi.ec.row = row 177 178 b, err := wi.ec.evalBoolExpr(wi.where) 179 if err != nil { 180 return nil, err 181 } 182 if b != nil && *b { 183 return row, nil 184 } 185 } 186} 187 188// selIter applies a SELECT list. 189type selIter struct { 190 ri rowIter 191 ec evalContext 192 cis []colInfo 193 list []spansql.Expr 194 195 distinct bool // whether this is a SELECT DISTINCT 196 seen []row 197} 198 199func (si *selIter) Cols() []colInfo { return si.cis } 200func (si *selIter) Next() (row, error) { 201 for { 202 r, err := si.next() 203 if err != nil { 204 return nil, err 205 } 206 if si.distinct && !si.keep(r) { 207 continue 208 } 209 return r, nil 210 } 211} 212 213// next retrieves the next row for the SELECT and evaluates its expression list. 214func (si *selIter) next() (row, error) { 215 r, err := si.ri.Next() 216 if err != nil { 217 return nil, err 218 } 219 si.ec.row = r 220 221 var out row 222 for _, e := range si.list { 223 if e == spansql.Star { 224 out = append(out, r...) 
225 } else { 226 v, err := si.ec.evalExpr(e) 227 if err != nil { 228 return nil, err 229 } 230 out = append(out, v) 231 } 232 } 233 return out, nil 234} 235 236func (si *selIter) keep(r row) bool { 237 // This is hilariously inefficient; O(N^2) in the number of returned rows. 238 // Some sort of hashing could be done to deduplicate instead. 239 // This also breaks on array/struct types. 240 for _, prev := range si.seen { 241 if rowEqual(prev, r) { 242 return false 243 } 244 } 245 si.seen = append(si.seen, r) 246 return true 247} 248 249// offsetIter applies an OFFSET clause. 250type offsetIter struct { 251 ri rowIter 252 skip int64 253} 254 255func (oi *offsetIter) Cols() []colInfo { return oi.ri.Cols() } 256func (oi *offsetIter) Next() (row, error) { 257 for oi.skip > 0 { 258 _, err := oi.ri.Next() 259 if err != nil { 260 return nil, err 261 } 262 oi.skip-- 263 } 264 row, err := oi.ri.Next() 265 if err != nil { 266 return nil, err 267 } 268 return row, nil 269} 270 271// limitIter applies a LIMIT clause. 272type limitIter struct { 273 ri rowIter 274 rem int64 275} 276 277func (li *limitIter) Cols() []colInfo { return li.ri.Cols() } 278func (li *limitIter) Next() (row, error) { 279 if li.rem <= 0 { 280 return nil, io.EOF 281 } 282 row, err := li.ri.Next() 283 if err != nil { 284 return nil, err 285 } 286 li.rem-- 287 return row, nil 288} 289 290type queryParam struct { 291 Value interface{} // internal representation 292 Type spansql.Type 293} 294 295type queryParams map[string]queryParam // TODO: change key to spansql.Param? 296 297type queryContext struct { 298 params queryParams 299 300 tables []*table // sorted by name 301 tableIndex map[spansql.ID]*table 302 locks int 303} 304 305func (qc *queryContext) Lock() { 306 // Take locks in name order. 
307 for _, t := range qc.tables { 308 t.mu.Lock() 309 qc.locks++ 310 } 311} 312 313func (qc *queryContext) Unlock() { 314 for _, t := range qc.tables { 315 t.mu.Unlock() 316 qc.locks-- 317 } 318} 319 320func (d *database) Query(q spansql.Query, params queryParams) (ri rowIter, err error) { 321 // Figure out the context of the query and take any required locks. 322 qc, err := d.queryContext(q, params) 323 if err != nil { 324 return nil, err 325 } 326 qc.Lock() 327 // On the way out, if there were locks taken, flatten the output 328 // and release the locks. 329 if qc.locks > 0 { 330 defer func() { 331 if err == nil { 332 ri, err = toRawIter(ri) 333 } 334 qc.Unlock() 335 }() 336 } 337 338 // Prepare auxiliary expressions to evaluate for ORDER BY. 339 var aux []spansql.Expr 340 var desc []bool 341 for _, o := range q.Order { 342 aux = append(aux, o.Expr) 343 desc = append(desc, o.Desc) 344 } 345 346 si, err := d.evalSelect(q.Select, qc) 347 if err != nil { 348 return nil, err 349 } 350 ri = si 351 352 // Apply ORDER BY. 353 if len(q.Order) > 0 { 354 // Evaluate the selIter completely, and sort the rows by the auxiliary expressions. 355 rows, keys, err := evalSelectOrder(si, aux) 356 if err != nil { 357 return nil, err 358 } 359 sort.Sort(externalRowSorter{rows: rows, keys: keys, desc: desc}) 360 ri = &rawIter{cols: si.cis, rows: rows} 361 } 362 363 // Apply LIMIT, OFFSET. 364 if q.Limit != nil { 365 if q.Offset != nil { 366 off, err := evalLiteralOrParam(q.Offset, params) 367 if err != nil { 368 return nil, err 369 } 370 ri = &offsetIter{ri: ri, skip: off} 371 } 372 373 lim, err := evalLiteralOrParam(q.Limit, params) 374 if err != nil { 375 return nil, err 376 } 377 ri = &limitIter{ri: ri, rem: lim} 378 } 379 380 return ri, nil 381} 382 383func (d *database) queryContext(q spansql.Query, params queryParams) (*queryContext, error) { 384 qc := &queryContext{ 385 params: params, 386 } 387 388 // Look for any mentioned tables and add them to qc.tableIndex. 
389 addTable := func(name spansql.ID) error { 390 if _, ok := qc.tableIndex[name]; ok { 391 return nil // Already found this table. 392 } 393 t, err := d.table(name) 394 if err != nil { 395 return err 396 } 397 if qc.tableIndex == nil { 398 qc.tableIndex = make(map[spansql.ID]*table) 399 } 400 qc.tableIndex[name] = t 401 return nil 402 } 403 var findTables func(sf spansql.SelectFrom) error 404 findTables = func(sf spansql.SelectFrom) error { 405 switch sf := sf.(type) { 406 default: 407 return fmt.Errorf("can't prepare query context for SelectFrom of type %T", sf) 408 case spansql.SelectFromTable: 409 return addTable(sf.Table) 410 case spansql.SelectFromJoin: 411 if err := findTables(sf.LHS); err != nil { 412 return err 413 } 414 return findTables(sf.RHS) 415 case spansql.SelectFromUnnest: 416 // TODO: if array paths get supported, this will need more work. 417 return nil 418 } 419 } 420 for _, sf := range q.Select.From { 421 if err := findTables(sf); err != nil { 422 return nil, err 423 } 424 } 425 426 // Build qc.tables in name order so we can take locks in a well-defined order. 427 var names []spansql.ID 428 for name := range qc.tableIndex { 429 names = append(names, name) 430 } 431 sort.Slice(names, func(i, j int) bool { return names[i] < names[j] }) 432 for _, name := range names { 433 qc.tables = append(qc.tables, qc.tableIndex[name]) 434 } 435 436 return qc, nil 437} 438 439func (d *database) evalSelect(sel spansql.Select, qc *queryContext) (si *selIter, evalErr error) { 440 var ri rowIter = &nullIter{} 441 ec := evalContext{ 442 params: qc.params, 443 } 444 445 // First stage is to identify the data source. 446 // If there's a FROM then that names a table to use. 447 if len(sel.From) > 1 { 448 return nil, fmt.Errorf("selecting with more than one FROM clause not yet supported") 449 } 450 if len(sel.From) == 1 { 451 var err error 452 ec, ri, err = d.evalSelectFrom(qc, ec, sel.From[0]) 453 if err != nil { 454 return nil, err 455 } 456 } 457 458 // Apply WHERE. 
459 if sel.Where != nil { 460 ri = whereIter{ 461 ri: ri, 462 ec: ec, 463 where: sel.Where, 464 } 465 } 466 467 // Load aliases visible to any future iterators, 468 // including GROUP BY and ORDER BY. These are not visible to the WHERE clause. 469 ec.aliases = make(map[spansql.ID]spansql.Expr) 470 for i, alias := range sel.ListAliases { 471 ec.aliases[alias] = sel.List[i] 472 } 473 // TODO: Add aliases for "1", "2", etc. 474 475 // Apply GROUP BY. 476 // This only reorders rows to group rows together; 477 // aggregation happens next. 478 var rowGroups [][2]int // Sequence of half-open intervals of row numbers. 479 if len(sel.GroupBy) > 0 { 480 raw, err := toRawIter(ri) 481 if err != nil { 482 return nil, err 483 } 484 keys := make([][]interface{}, 0, len(raw.rows)) 485 for _, row := range raw.rows { 486 // Evaluate sort key for this row. 487 ec.row = row 488 key, err := ec.evalExprList(sel.GroupBy) 489 if err != nil { 490 return nil, err 491 } 492 keys = append(keys, key) 493 } 494 495 // Reorder rows base on the evaluated keys. 496 ers := externalRowSorter{rows: raw.rows, keys: keys} 497 sort.Sort(ers) 498 raw.rows = ers.rows 499 500 // Record groups as a sequence of row intervals. 501 // Each group is a run of the same keys. 502 start := 0 503 for i := 1; i < len(keys); i++ { 504 if compareValLists(keys[i-1], keys[i], nil) == 0 { 505 continue 506 } 507 rowGroups = append(rowGroups, [2]int{start, i}) 508 start = i 509 } 510 if len(keys) > 0 { 511 rowGroups = append(rowGroups, [2]int{start, len(keys)}) 512 } 513 514 // Clear aliases, since they aren't visible elsewhere. 515 ec.aliases = nil 516 517 ri = raw 518 } 519 520 // Handle aggregation. 521 // TODO: Support more than one aggregation function; does Spanner support that? 522 aggI := -1 523 for i, e := range sel.List { 524 // Supported aggregate funcs have exactly one arg. 
525 f, ok := e.(spansql.Func) 526 if !ok || len(f.Args) != 1 { 527 continue 528 } 529 _, ok = aggregateFuncs[f.Name] 530 if !ok { 531 continue 532 } 533 if aggI > -1 { 534 return nil, fmt.Errorf("only one aggregate function is supported") 535 } 536 aggI = i 537 } 538 if aggI > -1 { 539 raw, err := toRawIter(ri) 540 if err != nil { 541 return nil, err 542 } 543 if len(sel.GroupBy) == 0 { 544 // No grouping, so aggregation applies to the entire table (e.g. COUNT(*)). 545 // This may result in a [0,0) entry for empty inputs. 546 rowGroups = [][2]int{{0, len(raw.rows)}} 547 } 548 fexpr := sel.List[aggI].(spansql.Func) 549 fn := aggregateFuncs[fexpr.Name] 550 starArg := fexpr.Args[0] == spansql.Star 551 if starArg && !fn.AcceptStar { 552 return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name) 553 } 554 var argType spansql.Type 555 if !starArg { 556 ci, err := ec.colInfo(fexpr.Args[0]) 557 if err != nil { 558 return nil, fmt.Errorf("evaluating aggregate function %s arg type: %v", fexpr.Name, err) 559 } 560 argType = ci.Type 561 } 562 563 // Prepare output. 564 rawOut := &rawIter{ 565 // Same as input columns, but also the aggregate value. 566 // Add the colInfo for the aggregate at the end 567 // so we know the type. 568 // Make a copy for safety. 569 cols: append([]colInfo(nil), raw.cols...), 570 } 571 572 var aggType spansql.Type 573 for _, rg := range rowGroups { 574 // Compute aggregate value across this group. 575 var values []interface{} 576 for i := rg[0]; i < rg[1]; i++ { 577 ec.row = raw.rows[i] 578 if starArg { 579 // A non-NULL placeholder is sufficient for aggregation. 
580 values = append(values, 1) 581 } else { 582 x, err := ec.evalExpr(fexpr.Args[0]) 583 if err != nil { 584 return nil, err 585 } 586 values = append(values, x) 587 } 588 } 589 x, typ, err := fn.Eval(values, argType) 590 if err != nil { 591 return nil, err 592 } 593 aggType = typ 594 595 var outRow row 596 // Output for the row group is the first row of the group (arbitrary, 597 // but it should be representative), and the aggregate value. 598 // TODO: Should this exclude the aggregated expressions so they can't be selected? 599 // If the row group is empty then only the aggregation value is used; 600 // this covers things like COUNT(*) with no matching rows. 601 if rg[0] < len(raw.rows) { 602 repRow := raw.rows[rg[0]] 603 for i := range repRow { 604 outRow = append(outRow, repRow.copyDataElem(i)) 605 } 606 } else { 607 // Fill with NULLs to keep the rows and colInfo aligned. 608 for i := 0; i < len(rawOut.cols); i++ { 609 outRow = append(outRow, nil) 610 } 611 } 612 outRow = append(outRow, x) 613 rawOut.rows = append(rawOut.rows, outRow) 614 } 615 616 if aggType == (spansql.Type{}) { 617 // Fallback; there might not be any groups. 618 // TODO: Should this be in aggregateFunc? 619 aggType = int64Type 620 } 621 rawOut.cols = append(raw.cols, colInfo{ 622 Name: spansql.ID(fexpr.SQL()), // TODO: this is a bit hokey, but it is output only 623 Type: aggType, 624 AggIndex: aggI + 1, 625 }) 626 627 ri = rawOut 628 ec.cols = rawOut.cols 629 sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value. 630 Type: aggType, 631 AggIndex: aggI + 1, 632 } 633 } 634 635 // TODO: Support table sampling. 636 637 // Apply SELECT list. 638 var colInfos []colInfo 639 for i, e := range sel.List { 640 if e == spansql.Star { 641 colInfos = append(colInfos, ec.cols...) 
642 } else { 643 ci, err := ec.colInfo(e) 644 if err != nil { 645 return nil, err 646 } 647 if len(sel.ListAliases) > 0 { 648 alias := sel.ListAliases[i] 649 if alias != "" { 650 ci.Name = alias 651 } 652 } 653 // TODO: deal with ci.Name == ""? 654 colInfos = append(colInfos, ci) 655 } 656 } 657 658 return &selIter{ 659 ri: ri, 660 ec: ec, 661 cis: colInfos, 662 list: sel.List, 663 664 distinct: sel.Distinct, // Apply DISTINCT. 665 }, nil 666} 667 668func (d *database) evalSelectFrom(qc *queryContext, ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, error) { 669 switch sf := sf.(type) { 670 default: 671 return ec, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf) 672 case spansql.SelectFromTable: 673 t, ok := qc.tableIndex[sf.Table] 674 if !ok { 675 // This shouldn't be possible; the queryContext should have discovered missing tables already. 676 return ec, nil, fmt.Errorf("unknown table %q", sf.Table) 677 } 678 ti := &tableIter{t: t} 679 if sf.Alias != "" { 680 ti.alias = sf.Alias 681 } else { 682 // There is an implicit alias using the table name. 683 // https://cloud.google.com/spanner/docs/query-syntax#implicit_aliases 684 ti.alias = sf.Table 685 } 686 ec.cols = ti.Cols() 687 return ec, ti, nil 688 case spansql.SelectFromJoin: 689 // TODO: Avoid the toRawIter calls here by doing the RHS recursive evalSelectFrom in joinIter.Next on demand. 
690 691 lhsEC, lhs, err := d.evalSelectFrom(qc, ec, sf.LHS) 692 if err != nil { 693 return ec, nil, err 694 } 695 lhsRaw, err := toRawIter(lhs) 696 if err != nil { 697 return ec, nil, err 698 } 699 700 rhsEC, rhs, err := d.evalSelectFrom(qc, ec, sf.RHS) 701 if err != nil { 702 return ec, nil, err 703 } 704 rhsRaw, err := toRawIter(rhs) 705 if err != nil { 706 return ec, nil, err 707 } 708 709 ji, ec, err := newJoinIter(lhsRaw, rhsRaw, lhsEC, rhsEC, sf) 710 if err != nil { 711 return ec, nil, err 712 } 713 return ec, ji, nil 714 case spansql.SelectFromUnnest: 715 // TODO: Do all relevant types flow through here? Path expressions might be tricky here. 716 col, err := ec.colInfo(sf.Expr) 717 if err != nil { 718 return ec, nil, fmt.Errorf("evaluating type of UNNEST arg: %v", err) 719 } 720 if !col.Type.Array { 721 return ec, nil, fmt.Errorf("type of UNNEST arg is non-array %s", col.Type.SQL()) 722 } 723 // The output of this UNNEST is the non-array version. 724 col.Name = sf.Alias // may be empty 725 col.Type.Array = false 726 727 // Evaluate the expression, and yield a virtual table with one column. 728 e, err := ec.evalExpr(sf.Expr) 729 if err != nil { 730 return ec, nil, fmt.Errorf("evaluating UNNEST arg: %v", err) 731 } 732 arr, ok := e.([]interface{}) 733 if !ok { 734 return ec, nil, fmt.Errorf("evaluating UNNEST arg gave %t, want array", e) 735 } 736 var rows []row 737 for _, v := range arr { 738 rows = append(rows, row{v}) 739 } 740 741 ri := &rawIter{ 742 cols: []colInfo{col}, 743 rows: rows, 744 } 745 ec.cols = ri.cols 746 return ec, ri, nil 747 } 748} 749 750func newJoinIter(lhs, rhs *rawIter, lhsEC, rhsEC evalContext, sfj spansql.SelectFromJoin) (*joinIter, evalContext, error) { 751 if sfj.On != nil && len(sfj.Using) > 0 { 752 return nil, evalContext{}, fmt.Errorf("JOIN may not have both ON and USING clauses") 753 } 754 if sfj.On == nil && len(sfj.Using) == 0 && sfj.Type != spansql.CrossJoin { 755 // TODO: This isn't correct for joining against a non-table. 
756 return nil, evalContext{}, fmt.Errorf("non-CROSS JOIN must have ON or USING clause") 757 } 758 759 // Start with the context from the LHS (aliases and params should be the same on both sides). 760 ji := &joinIter{ 761 jt: sfj.Type, 762 ec: lhsEC, 763 764 primary: lhs, 765 secondaryOrig: rhs, 766 767 primaryOffset: 0, 768 secondaryOffset: len(lhsEC.cols), 769 } 770 switch ji.jt { 771 case spansql.LeftJoin: 772 ji.nullPad = true 773 case spansql.RightJoin: 774 ji.nullPad = true 775 // Primary is RHS. 776 ji.ec = rhsEC 777 ji.primary, ji.secondaryOrig = rhs, lhs 778 ji.primaryOffset, ji.secondaryOffset = len(rhsEC.cols), 0 779 case spansql.FullJoin: 780 // FULL JOIN is implemented as a LEFT JOIN with tracking for which rows of the RHS 781 // have been used. Then, at the end of the iteration, the unused RHS rows are emitted. 782 ji.nullPad = true 783 ji.used = make([]bool, 0, 10) // arbitrary preallocation 784 } 785 ji.ec.cols, ji.ec.row = nil, nil 786 787 // Construct a merged evalContext, and prepare the join condition evaluation. 788 // TODO: Remove ambiguous names here? Or catch them when evaluated? 789 // TODO: aliases might need work? 790 if len(sfj.Using) == 0 { 791 ji.prepNonUsing(sfj.On, lhsEC, rhsEC) 792 } else { 793 if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC, ji.jt == spansql.RightJoin); err != nil { 794 return nil, evalContext{}, err 795 } 796 } 797 798 return ji, ji.ec, nil 799} 800 801// prepNonUsing configures the joinIter to evaluate with an ON clause or no join clause. 802// The arg is nil in the latter case. 803func (ji *joinIter) prepNonUsing(on spansql.BoolExpr, lhsEC, rhsEC evalContext) { 804 // Having ON or no clause results in the full set of columns from both sides. 805 // Force a copy. 806 ji.ec.cols = append(ji.ec.cols, lhsEC.cols...) 807 ji.ec.cols = append(ji.ec.cols, rhsEC.cols...) 
808 ji.ec.row = make(row, len(ji.ec.cols)) 809 810 ji.cond = func(primary, secondary row) (bool, error) { 811 copy(ji.ec.row[ji.primaryOffset:], primary) 812 copy(ji.ec.row[ji.secondaryOffset:], secondary) 813 if on == nil { 814 // No condition; all rows match. 815 return true, nil 816 } 817 b, err := ji.ec.evalBoolExpr(on) 818 if err != nil { 819 return false, err 820 } 821 return b != nil && *b, nil 822 } 823 ji.zero = func(primary, secondary row) { 824 for i := range ji.ec.row { 825 ji.ec.row[i] = nil 826 } 827 copy(ji.ec.row[ji.primaryOffset:], primary) 828 copy(ji.ec.row[ji.secondaryOffset:], secondary) 829 } 830} 831 832func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext, flipped bool) error { 833 // Having a USING clause results in the set of named columns once, 834 // followed by the unnamed columns from both sides. 835 836 // lhsUsing is the column indexes in the LHS that the USING clause references. 837 // rhsUsing is similar. 838 // lhsNotUsing/rhsNotUsing are the complement. 839 var lhsUsing, rhsUsing []int 840 var lhsNotUsing, rhsNotUsing []int 841 // lhsUsed, rhsUsed are the set of column indexes in lhsUsing/rhsUsing. 842 lhsUsed, rhsUsed := make(map[int]bool), make(map[int]bool) 843 for _, id := range using { 844 lhsi, err := lhsEC.resolveColumnIndex(id) 845 if err != nil { 846 return err 847 } 848 lhsUsing = append(lhsUsing, lhsi) 849 lhsUsed[lhsi] = true 850 851 rhsi, err := rhsEC.resolveColumnIndex(id) 852 if err != nil { 853 return err 854 } 855 rhsUsing = append(rhsUsing, rhsi) 856 rhsUsed[rhsi] = true 857 858 // TODO: Should this hide or merge column aliases? 
859 ji.ec.cols = append(ji.ec.cols, lhsEC.cols[lhsi]) 860 } 861 for i, col := range lhsEC.cols { 862 if !lhsUsed[i] { 863 ji.ec.cols = append(ji.ec.cols, col) 864 lhsNotUsing = append(lhsNotUsing, i) 865 } 866 } 867 for i, col := range rhsEC.cols { 868 if !rhsUsed[i] { 869 ji.ec.cols = append(ji.ec.cols, col) 870 rhsNotUsing = append(rhsNotUsing, i) 871 } 872 } 873 ji.ec.row = make(row, len(ji.ec.cols)) 874 875 primaryUsing, secondaryUsing := lhsUsing, rhsUsing 876 if flipped { 877 primaryUsing, secondaryUsing = secondaryUsing, primaryUsing 878 } 879 880 orNil := func(r row, i int) interface{} { 881 if r == nil { 882 return nil 883 } 884 return r[i] 885 } 886 // populate writes the data to ji.ec.row in the correct positions. 887 populate := func(primary, secondary row) { // either may be nil 888 j := 0 889 if primary != nil { 890 for _, pi := range primaryUsing { 891 ji.ec.row[j] = primary[pi] 892 j++ 893 } 894 } else { 895 for _, si := range secondaryUsing { 896 ji.ec.row[j] = secondary[si] 897 j++ 898 } 899 } 900 lhs, rhs := primary, secondary 901 if flipped { 902 rhs, lhs = lhs, rhs 903 } 904 for _, i := range lhsNotUsing { 905 ji.ec.row[j] = orNil(lhs, i) 906 j++ 907 } 908 for _, i := range rhsNotUsing { 909 ji.ec.row[j] = orNil(rhs, i) 910 j++ 911 } 912 } 913 ji.cond = func(primary, secondary row) (bool, error) { 914 for i, pi := range primaryUsing { 915 si := secondaryUsing[i] 916 if compareVals(primary[pi], secondary[si]) != 0 { 917 return false, nil 918 } 919 } 920 populate(primary, secondary) 921 return true, nil 922 } 923 ji.zero = func(primary, secondary row) { 924 populate(primary, secondary) 925 } 926 return nil 927} 928 929type joinIter struct { 930 jt spansql.JoinType 931 ec evalContext // combined context 932 933 // The "primary" is scanned (consumed), but the secondary is cloned for each primary row. 934 // Most join types have primary==LHS; a RIGHT JOIN is the exception. 
935 primary, secondaryOrig *rawIter 936 937 // The offsets into ec.row that the primary/secondary rows should appear 938 // in the final output. Not used when there's a USING clause. 939 primaryOffset, secondaryOffset int 940 // nullPad is whether primary rows without matching secondary rows 941 // should be yielded with null padding (e.g. OUTER JOINs). 942 nullPad bool 943 944 primaryRow row // current row from primary, or nil if it is time to advance 945 secondary *rawIter // current clone of secondary 946 secondaryRead int // number of rows already read from secondary 947 any bool // true if any secondary rows have matched primaryRow 948 949 // cond reports whether the primary and secondary rows "join" (e.g. the ON clause is true). 950 // It populates ec.row with the output. 951 cond func(primary, secondary row) (bool, error) 952 // zero populates ec.row with the primary or secondary row data (either of which may be nil), 953 // and sets the remainder to NULL. 954 // This is used when nullPad is true and a primary or secondary row doesn't match. 955 zero func(primary, secondary row) 956 957 // For FULL JOIN, this tracks the secondary rows that have been used. 958 // It is non-nil when being used. 959 used []bool 960 zeroUnused bool // set when emitting unused secondary rows 961 unusedIndex int // next index of used to check 962} 963 964func (ji *joinIter) Cols() []colInfo { return ji.ec.cols } 965 966func (ji *joinIter) nextPrimary() error { 967 var err error 968 ji.primaryRow, err = ji.primary.Next() 969 if err != nil { 970 return err 971 } 972 ji.secondary = ji.secondaryOrig.clone() 973 ji.secondaryRead = 0 974 ji.any = false 975 return nil 976} 977 978func (ji *joinIter) Next() (row, error) { 979 if ji.primaryRow == nil && !ji.zeroUnused { 980 err := ji.nextPrimary() 981 if err == io.EOF && ji.used != nil { 982 // Drop down to emitting unused secondary rows. 
983 ji.zeroUnused = true 984 ji.secondary = nil 985 goto scanJiUsed 986 } 987 if err != nil { 988 return nil, err 989 } 990 } 991scanJiUsed: 992 if ji.zeroUnused { 993 if ji.secondary == nil { 994 ji.secondary = ji.secondaryOrig.clone() 995 ji.secondaryRead = 0 996 } 997 for ji.unusedIndex < len(ji.used) && ji.used[ji.unusedIndex] { 998 ji.unusedIndex++ 999 } 1000 if ji.unusedIndex >= len(ji.used) || ji.secondaryRead >= len(ji.used) { 1001 // Truly finished. 1002 return nil, io.EOF 1003 } 1004 var secondaryRow row 1005 for ji.secondaryRead <= ji.unusedIndex { 1006 var err error 1007 secondaryRow, err = ji.secondary.Next() 1008 if err != nil { 1009 return nil, err 1010 } 1011 ji.secondaryRead++ 1012 } 1013 ji.zero(nil, secondaryRow) 1014 return ji.ec.row, nil 1015 } 1016 1017 for { 1018 secondaryRow, err := ji.secondary.Next() 1019 if err == io.EOF { 1020 // Finished the current primary row. 1021 1022 if !ji.any && ji.nullPad { 1023 ji.zero(ji.primaryRow, nil) 1024 ji.primaryRow = nil 1025 return ji.ec.row, nil 1026 } 1027 1028 // Advance to next one. 1029 err := ji.nextPrimary() 1030 if err == io.EOF && ji.used != nil { 1031 ji.zeroUnused = true 1032 ji.secondary = nil 1033 goto scanJiUsed 1034 } 1035 if err != nil { 1036 return nil, err 1037 } 1038 continue 1039 } 1040 if err != nil { 1041 return nil, err 1042 } 1043 ji.secondaryRead++ 1044 if ji.used != nil { 1045 for len(ji.used) < ji.secondaryRead { 1046 ji.used = append(ji.used, false) 1047 } 1048 } 1049 1050 // We have a pair of rows to consider. 1051 match, err := ji.cond(ji.primaryRow, secondaryRow) 1052 if err != nil { 1053 return nil, err 1054 } 1055 if !match { 1056 continue 1057 } 1058 ji.any = true 1059 if ji.used != nil { 1060 // Make a note that we used this secondary row. 
1061 ji.used[ji.secondaryRead-1] = true 1062 } 1063 return ji.ec.row, nil 1064 } 1065} 1066 1067func evalSelectOrder(si *selIter, aux []spansql.Expr) (rows []row, keys [][]interface{}, err error) { 1068 // This is like toRawIter except it also evaluates the auxiliary expressions for ORDER BY. 1069 for { 1070 r, err := si.Next() 1071 if err == io.EOF { 1072 break 1073 } else if err != nil { 1074 return nil, nil, err 1075 } 1076 key, err := si.ec.evalExprList(aux) 1077 if err != nil { 1078 return nil, nil, err 1079 } 1080 1081 rows = append(rows, r.copyAllData()) 1082 keys = append(keys, key) 1083 } 1084 return 1085} 1086 1087// externalRowSorter implements sort.Interface for a slice of rows 1088// with an external sort key. 1089type externalRowSorter struct { 1090 rows []row 1091 keys [][]interface{} 1092 desc []bool // may be nil 1093} 1094 1095func (ers externalRowSorter) Len() int { return len(ers.rows) } 1096func (ers externalRowSorter) Less(i, j int) bool { 1097 return compareValLists(ers.keys[i], ers.keys[j], ers.desc) < 0 1098} 1099func (ers externalRowSorter) Swap(i, j int) { 1100 ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i] 1101 ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i] 1102} 1103