/*
Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package spannertest

import (
	"fmt"
	"io"
	"sort"

	"cloud.google.com/go/spanner/spansql"
)

/*
There are several ways to conceptualise SQL queries. The simplest, and what
we implement here, is a series of pipelines that transform the data, whether
pulling from a table (FROM tbl), filtering (WHERE expr), re-ordering (ORDER BY expr)
or other transformations.

The order of operations among those supported by Cloud Spanner is
	FROM + JOIN + set ops [TODO: set ops]
	WHERE
	GROUP BY
	aggregation
	HAVING [TODO]
	SELECT
	DISTINCT
	ORDER BY
	OFFSET
	LIMIT
*/

// rowIter represents some iteration over rows of data.
// It is returned by reads and queries.
type rowIter interface {
	// Cols returns the metadata about the returned data.
	Cols() []colInfo

	// Next returns the next row.
	// If done, it returns (nil, io.EOF).
	Next() (row, error)
}

// aggSentinel is a synthetic expression that refers to an aggregated value.
// It is transient only; it is never stored and only used during evaluation.
type aggSentinel struct {
	// Expr is embedded so that aggSentinel can be used where a spansql.Expr is expected.
	spansql.Expr
	// Type is the type of the aggregated value.
	Type spansql.Type
	// AggIndex identifies which SELECT list entry the aggregate came from.
	AggIndex int // Index+1 of SELECT list.
}

// nullIter is a rowIter that returns one empty row only.
// This is used for queries without a table.
67type nullIter struct { 68 done bool 69} 70 71func (ni *nullIter) Cols() []colInfo { return nil } 72func (ni *nullIter) Next() (row, error) { 73 if ni.done { 74 return nil, io.EOF 75 } 76 ni.done = true 77 return nil, nil 78} 79 80// tableIter is a rowIter that walks a table. 81// It assumes the table is locked for the duration. 82type tableIter struct { 83 t *table 84 rowIndex int // index of next row to return 85 86 alias spansql.ID // if non-empty, "AS <alias>" 87} 88 89func (ti *tableIter) Cols() []colInfo { 90 // Build colInfo in the original column order. 91 cis := make([]colInfo, len(ti.t.cols)) 92 for _, ci := range ti.t.cols { 93 if ti.alias != "" { 94 ci.Alias = spansql.PathExp{ti.alias, ci.Name} 95 } 96 cis[ti.t.origIndex[ci.Name]] = ci 97 } 98 return cis 99} 100 101func (ti *tableIter) Next() (row, error) { 102 if ti.rowIndex >= len(ti.t.rows) { 103 return nil, io.EOF 104 } 105 r := ti.t.rows[ti.rowIndex] 106 ti.rowIndex++ 107 108 // Build output row in the original column order. 109 res := make(row, len(r)) 110 for i, ci := range ti.t.cols { 111 res[ti.t.origIndex[ci.Name]] = r[i] 112 } 113 114 return res, nil 115} 116 117// rawIter is a rowIter with fixed data. 118type rawIter struct { 119 // cols is the metadata about the returned data. 120 cols []colInfo 121 122 // rows holds the result data itself. 123 rows []row 124} 125 126func (raw *rawIter) Cols() []colInfo { return raw.cols } 127func (raw *rawIter) Next() (row, error) { 128 if len(raw.rows) == 0 { 129 return nil, io.EOF 130 } 131 res := raw.rows[0] 132 raw.rows = raw.rows[1:] 133 return res, nil 134} 135 136func (raw *rawIter) add(src row, colIndexes []int) { 137 raw.rows = append(raw.rows, src.copyData(colIndexes)) 138} 139 140// clone makes a shallow copy. 
141func (raw *rawIter) clone() *rawIter { 142 return &rawIter{cols: raw.cols, rows: raw.rows} 143} 144 145func toRawIter(ri rowIter) (*rawIter, error) { 146 if raw, ok := ri.(*rawIter); ok { 147 return raw, nil 148 } 149 raw := &rawIter{cols: ri.Cols()} 150 for { 151 row, err := ri.Next() 152 if err == io.EOF { 153 break 154 } else if err != nil { 155 return nil, err 156 } 157 raw.rows = append(raw.rows, row.copyAllData()) 158 } 159 return raw, nil 160} 161 162// whereIter applies a WHERE clause. 163type whereIter struct { 164 ri rowIter 165 ec evalContext 166 where spansql.BoolExpr 167} 168 169func (wi whereIter) Cols() []colInfo { return wi.ri.Cols() } 170func (wi whereIter) Next() (row, error) { 171 for { 172 row, err := wi.ri.Next() 173 if err != nil { 174 return nil, err 175 } 176 wi.ec.row = row 177 178 b, err := wi.ec.evalBoolExpr(wi.where) 179 if err != nil { 180 return nil, err 181 } 182 if b != nil && *b { 183 return row, nil 184 } 185 } 186} 187 188// selIter applies a SELECT list. 189type selIter struct { 190 ri rowIter 191 ec evalContext 192 cis []colInfo 193 list []spansql.Expr 194 195 distinct bool // whether this is a SELECT DISTINCT 196 seen []row 197} 198 199func (si *selIter) Cols() []colInfo { return si.cis } 200func (si *selIter) Next() (row, error) { 201 for { 202 r, err := si.next() 203 if err != nil { 204 return nil, err 205 } 206 if si.distinct && !si.keep(r) { 207 continue 208 } 209 return r, nil 210 } 211} 212 213// next retrieves the next row for the SELECT and evaluates its expression list. 214func (si *selIter) next() (row, error) { 215 r, err := si.ri.Next() 216 if err != nil { 217 return nil, err 218 } 219 si.ec.row = r 220 221 var out row 222 for _, e := range si.list { 223 if e == spansql.Star { 224 out = append(out, r...) 
225 } else { 226 v, err := si.ec.evalExpr(e) 227 if err != nil { 228 return nil, err 229 } 230 out = append(out, v) 231 } 232 } 233 return out, nil 234} 235 236func (si *selIter) keep(r row) bool { 237 // This is hilariously inefficient; O(N^2) in the number of returned rows. 238 // Some sort of hashing could be done to deduplicate instead. 239 // This also breaks on array/struct types. 240 for _, prev := range si.seen { 241 if rowEqual(prev, r) { 242 return false 243 } 244 } 245 si.seen = append(si.seen, r) 246 return true 247} 248 249// offsetIter applies an OFFSET clause. 250type offsetIter struct { 251 ri rowIter 252 skip int64 253} 254 255func (oi *offsetIter) Cols() []colInfo { return oi.ri.Cols() } 256func (oi *offsetIter) Next() (row, error) { 257 for oi.skip > 0 { 258 _, err := oi.ri.Next() 259 if err != nil { 260 return nil, err 261 } 262 oi.skip-- 263 } 264 row, err := oi.ri.Next() 265 if err != nil { 266 return nil, err 267 } 268 return row, nil 269} 270 271// limitIter applies a LIMIT clause. 272type limitIter struct { 273 ri rowIter 274 rem int64 275} 276 277func (li *limitIter) Cols() []colInfo { return li.ri.Cols() } 278func (li *limitIter) Next() (row, error) { 279 if li.rem <= 0 { 280 return nil, io.EOF 281 } 282 row, err := li.ri.Next() 283 if err != nil { 284 return nil, err 285 } 286 li.rem-- 287 return row, nil 288} 289 290type queryParam struct { 291 Value interface{} // internal representation 292 Type spansql.Type 293} 294 295type queryParams map[string]queryParam // TODO: change key to spansql.Param? 296 297type queryContext struct { 298 params queryParams 299 300 tables []*table // sorted by name 301 tableIndex map[spansql.ID]*table 302 locks int 303} 304 305func (qc *queryContext) Lock() { 306 // Take locks in name order. 
307 for _, t := range qc.tables { 308 t.mu.Lock() 309 qc.locks++ 310 } 311} 312 313func (qc *queryContext) Unlock() { 314 for _, t := range qc.tables { 315 t.mu.Unlock() 316 qc.locks-- 317 } 318} 319 320func (d *database) Query(q spansql.Query, params queryParams) (ri rowIter, err error) { 321 // Figure out the context of the query and take any required locks. 322 qc, err := d.queryContext(q, params) 323 if err != nil { 324 return nil, err 325 } 326 qc.Lock() 327 // On the way out, if there were locks taken, flatten the output 328 // and release the locks. 329 if qc.locks > 0 { 330 defer func() { 331 if err == nil { 332 ri, err = toRawIter(ri) 333 } 334 qc.Unlock() 335 }() 336 } 337 338 // Prepare auxiliary expressions to evaluate for ORDER BY. 339 var aux []spansql.Expr 340 var desc []bool 341 for _, o := range q.Order { 342 aux = append(aux, o.Expr) 343 desc = append(desc, o.Desc) 344 } 345 346 si, err := d.evalSelect(q.Select, qc) 347 if err != nil { 348 return nil, err 349 } 350 ri = si 351 352 // Apply ORDER BY. 353 if len(q.Order) > 0 { 354 // Evaluate the selIter completely, and sort the rows by the auxiliary expressions. 355 rows, keys, err := evalSelectOrder(si, aux) 356 if err != nil { 357 return nil, err 358 } 359 sort.Sort(externalRowSorter{rows: rows, keys: keys, desc: desc}) 360 ri = &rawIter{cols: si.cis, rows: rows} 361 } 362 363 // Apply LIMIT, OFFSET. 364 if q.Limit != nil { 365 if q.Offset != nil { 366 off, err := evalLiteralOrParam(q.Offset, params) 367 if err != nil { 368 return nil, err 369 } 370 ri = &offsetIter{ri: ri, skip: off} 371 } 372 373 lim, err := evalLiteralOrParam(q.Limit, params) 374 if err != nil { 375 return nil, err 376 } 377 ri = &limitIter{ri: ri, rem: lim} 378 } 379 380 return ri, nil 381} 382 383func (d *database) queryContext(q spansql.Query, params queryParams) (*queryContext, error) { 384 qc := &queryContext{ 385 params: params, 386 } 387 388 // Look for any mentioned tables and add them to qc.tableIndex. 
389 addTable := func(name spansql.ID) error { 390 if _, ok := qc.tableIndex[name]; ok { 391 return nil // Already found this table. 392 } 393 t, err := d.table(name) 394 if err != nil { 395 return err 396 } 397 if qc.tableIndex == nil { 398 qc.tableIndex = make(map[spansql.ID]*table) 399 } 400 qc.tableIndex[name] = t 401 return nil 402 } 403 var findTables func(sf spansql.SelectFrom) error 404 findTables = func(sf spansql.SelectFrom) error { 405 switch sf := sf.(type) { 406 default: 407 return fmt.Errorf("can't prepare query context for SelectFrom of type %T", sf) 408 case spansql.SelectFromTable: 409 return addTable(sf.Table) 410 case spansql.SelectFromJoin: 411 if err := findTables(sf.LHS); err != nil { 412 return err 413 } 414 return findTables(sf.RHS) 415 } 416 } 417 for _, sf := range q.Select.From { 418 if err := findTables(sf); err != nil { 419 return nil, err 420 } 421 } 422 423 // Build qc.tables in name order so we can take locks in a well-defined order. 424 var names []spansql.ID 425 for name := range qc.tableIndex { 426 names = append(names, name) 427 } 428 sort.Slice(names, func(i, j int) bool { return names[i] < names[j] }) 429 for _, name := range names { 430 qc.tables = append(qc.tables, qc.tableIndex[name]) 431 } 432 433 return qc, nil 434} 435 436func (d *database) evalSelect(sel spansql.Select, qc *queryContext) (si *selIter, evalErr error) { 437 var ri rowIter = &nullIter{} 438 ec := evalContext{ 439 params: qc.params, 440 } 441 442 // First stage is to identify the data source. 443 // If there's a FROM then that names a table to use. 444 if len(sel.From) > 1 { 445 return nil, fmt.Errorf("selecting with more than one FROM clause not yet supported") 446 } 447 if len(sel.From) == 1 { 448 var err error 449 ec, ri, err = d.evalSelectFrom(qc, ec, sel.From[0]) 450 if err != nil { 451 return nil, err 452 } 453 } 454 455 // Apply WHERE. 
456 if sel.Where != nil { 457 ri = whereIter{ 458 ri: ri, 459 ec: ec, 460 where: sel.Where, 461 } 462 } 463 464 // Load aliases visible to any future iterators, 465 // including GROUP BY and ORDER BY. These are not visible to the WHERE clause. 466 ec.aliases = make(map[spansql.ID]spansql.Expr) 467 for i, alias := range sel.ListAliases { 468 ec.aliases[alias] = sel.List[i] 469 } 470 // TODO: Add aliases for "1", "2", etc. 471 472 // Apply GROUP BY. 473 // This only reorders rows to group rows together; 474 // aggregation happens next. 475 var rowGroups [][2]int // Sequence of half-open intervals of row numbers. 476 if len(sel.GroupBy) > 0 { 477 raw, err := toRawIter(ri) 478 if err != nil { 479 return nil, err 480 } 481 keys := make([][]interface{}, 0, len(raw.rows)) 482 for _, row := range raw.rows { 483 // Evaluate sort key for this row. 484 ec.row = row 485 key, err := ec.evalExprList(sel.GroupBy) 486 if err != nil { 487 return nil, err 488 } 489 keys = append(keys, key) 490 } 491 492 // Reorder rows base on the evaluated keys. 493 ers := externalRowSorter{rows: raw.rows, keys: keys} 494 sort.Sort(ers) 495 raw.rows = ers.rows 496 497 // Record groups as a sequence of row intervals. 498 // Each group is a run of the same keys. 499 start := 0 500 for i := 1; i < len(keys); i++ { 501 if compareValLists(keys[i-1], keys[i], nil) == 0 { 502 continue 503 } 504 rowGroups = append(rowGroups, [2]int{start, i}) 505 start = i 506 } 507 if len(keys) > 0 { 508 rowGroups = append(rowGroups, [2]int{start, len(keys)}) 509 } 510 511 // Clear aliases, since they aren't visible elsewhere. 512 ec.aliases = nil 513 514 ri = raw 515 } 516 517 // Handle aggregation. 518 // TODO: Support more than one aggregation function; does Spanner support that? 519 aggI := -1 520 for i, e := range sel.List { 521 // Supported aggregate funcs have exactly one arg. 
522 f, ok := e.(spansql.Func) 523 if !ok || len(f.Args) != 1 { 524 continue 525 } 526 _, ok = aggregateFuncs[f.Name] 527 if !ok { 528 continue 529 } 530 if aggI > -1 { 531 return nil, fmt.Errorf("only one aggregate function is supported") 532 } 533 aggI = i 534 } 535 if aggI > -1 { 536 raw, err := toRawIter(ri) 537 if err != nil { 538 return nil, err 539 } 540 if len(sel.GroupBy) == 0 { 541 // No grouping, so aggregation applies to the entire table (e.g. COUNT(*)). 542 // This may result in a [0,0) entry for empty inputs. 543 rowGroups = [][2]int{{0, len(raw.rows)}} 544 } 545 fexpr := sel.List[aggI].(spansql.Func) 546 fn := aggregateFuncs[fexpr.Name] 547 starArg := fexpr.Args[0] == spansql.Star 548 if starArg && !fn.AcceptStar { 549 return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name) 550 } 551 var argType spansql.Type 552 if !starArg { 553 ci, err := ec.colInfo(fexpr.Args[0]) 554 if err != nil { 555 return nil, err 556 } 557 argType = ci.Type 558 } 559 560 // Prepare output. 561 rawOut := &rawIter{ 562 // Same as input columns, but also the aggregate value. 563 // Add the colInfo for the aggregate at the end 564 // so we know the type. 565 // Make a copy for safety. 566 cols: append([]colInfo(nil), raw.cols...), 567 } 568 569 var aggType spansql.Type 570 for _, rg := range rowGroups { 571 // Compute aggregate value across this group. 572 var values []interface{} 573 for i := rg[0]; i < rg[1]; i++ { 574 ec.row = raw.rows[i] 575 if starArg { 576 // A non-NULL placeholder is sufficient for aggregation. 
577 values = append(values, 1) 578 } else { 579 x, err := ec.evalExpr(fexpr.Args[0]) 580 if err != nil { 581 return nil, err 582 } 583 values = append(values, x) 584 } 585 } 586 x, typ, err := fn.Eval(values, argType) 587 if err != nil { 588 return nil, err 589 } 590 aggType = typ 591 592 var outRow row 593 // Output for the row group is the first row of the group (arbitrary, 594 // but it should be representative), and the aggregate value. 595 // TODO: Should this exclude the aggregated expressions so they can't be selected? 596 // If the row group is empty then only the aggregation value is used; 597 // this covers things like COUNT(*) with no matching rows. 598 if rg[0] < len(raw.rows) { 599 repRow := raw.rows[rg[0]] 600 for i := range repRow { 601 outRow = append(outRow, repRow.copyDataElem(i)) 602 } 603 } else { 604 // Fill with NULLs to keep the rows and colInfo aligned. 605 for i := 0; i < len(rawOut.cols); i++ { 606 outRow = append(outRow, nil) 607 } 608 } 609 outRow = append(outRow, x) 610 rawOut.rows = append(rawOut.rows, outRow) 611 } 612 613 if aggType == (spansql.Type{}) { 614 // Fallback; there might not be any groups. 615 // TODO: Should this be in aggregateFunc? 616 aggType = int64Type 617 } 618 rawOut.cols = append(raw.cols, colInfo{ 619 Name: spansql.ID(fexpr.SQL()), // TODO: this is a bit hokey, but it is output only 620 Type: aggType, 621 AggIndex: aggI + 1, 622 }) 623 624 ri = rawOut 625 ec.cols = rawOut.cols 626 sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value. 627 Type: aggType, 628 AggIndex: aggI + 1, 629 } 630 } 631 632 // TODO: Support table sampling. 633 634 // Apply SELECT list. 635 var colInfos []colInfo 636 for i, e := range sel.List { 637 if e == spansql.Star { 638 colInfos = append(colInfos, ec.cols...) 
639 } else { 640 ci, err := ec.colInfo(e) 641 if err != nil { 642 return nil, err 643 } 644 if len(sel.ListAliases) > 0 { 645 alias := sel.ListAliases[i] 646 if alias != "" { 647 ci.Name = alias 648 } 649 } 650 // TODO: deal with ci.Name == ""? 651 colInfos = append(colInfos, ci) 652 } 653 } 654 655 return &selIter{ 656 ri: ri, 657 ec: ec, 658 cis: colInfos, 659 list: sel.List, 660 661 distinct: sel.Distinct, // Apply DISTINCT. 662 }, nil 663} 664 665func (d *database) evalSelectFrom(qc *queryContext, ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, error) { 666 switch sf := sf.(type) { 667 default: 668 return ec, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf) 669 case spansql.SelectFromTable: 670 t, ok := qc.tableIndex[sf.Table] 671 if !ok { 672 // This shouldn't be possible; the queryContext should have discovered missing tables already. 673 return ec, nil, fmt.Errorf("unknown table %q", sf.Table) 674 } 675 ti := &tableIter{t: t} 676 if sf.Alias != "" { 677 ti.alias = sf.Alias 678 } else { 679 // There is an implicit alias using the table name. 680 // https://cloud.google.com/spanner/docs/query-syntax#implicit_aliases 681 ti.alias = sf.Table 682 } 683 ec.cols = ti.Cols() 684 return ec, ti, nil 685 case spansql.SelectFromJoin: 686 // TODO: Avoid the toRawIter calls here by doing the RHS recursive evalSelectFrom in joinIter.Next on demand. 
687 688 lhsEC, lhs, err := d.evalSelectFrom(qc, ec, sf.LHS) 689 if err != nil { 690 return ec, nil, err 691 } 692 lhsRaw, err := toRawIter(lhs) 693 if err != nil { 694 return ec, nil, err 695 } 696 697 rhsEC, rhs, err := d.evalSelectFrom(qc, ec, sf.RHS) 698 if err != nil { 699 return ec, nil, err 700 } 701 rhsRaw, err := toRawIter(rhs) 702 if err != nil { 703 return ec, nil, err 704 } 705 706 ji, ec, err := newJoinIter(lhsRaw, rhsRaw, lhsEC, rhsEC, sf) 707 if err != nil { 708 return ec, nil, err 709 } 710 return ec, ji, nil 711 } 712} 713 714func newJoinIter(lhs, rhs *rawIter, lhsEC, rhsEC evalContext, sfj spansql.SelectFromJoin) (*joinIter, evalContext, error) { 715 if sfj.On != nil && len(sfj.Using) > 0 { 716 return nil, evalContext{}, fmt.Errorf("JOIN may not have both ON and USING clauses") 717 } 718 if sfj.On == nil && len(sfj.Using) == 0 && sfj.Type != spansql.CrossJoin { 719 // TODO: This isn't correct for joining against a non-table. 720 return nil, evalContext{}, fmt.Errorf("non-CROSS JOIN must have ON or USING clause") 721 } 722 723 // Start with the context from the LHS (aliases and params should be the same on both sides). 724 ji := &joinIter{ 725 jt: sfj.Type, 726 ec: lhsEC, 727 728 primary: lhs, 729 secondaryOrig: rhs, 730 731 primaryOffset: 0, 732 secondaryOffset: len(lhsEC.cols), 733 } 734 switch ji.jt { 735 case spansql.LeftJoin: 736 ji.nullPad = true 737 case spansql.RightJoin: 738 ji.nullPad = true 739 // Primary is RHS. 740 ji.ec = rhsEC 741 ji.primary, ji.secondaryOrig = rhs, lhs 742 ji.primaryOffset, ji.secondaryOffset = len(rhsEC.cols), 0 743 case spansql.FullJoin: 744 return nil, evalContext{}, fmt.Errorf("TODO: can't yet evaluate FULL JOIN") 745 } 746 ji.ec.cols, ji.ec.row = nil, nil 747 748 // Construct a merged evalContext, and prepare the join condition evaluation. 749 // TODO: Remove ambiguous names here? Or catch them when evaluated? 750 // TODO: aliases might need work? 
751 if len(sfj.Using) == 0 { 752 ji.prepNonUsing(sfj.On, lhsEC, rhsEC) 753 } else { 754 if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC, ji.jt == spansql.RightJoin); err != nil { 755 return nil, evalContext{}, err 756 } 757 } 758 759 return ji, ji.ec, nil 760} 761 762// prepNonUsing configures the joinIter to evaluate with an ON clause or no join clause. 763// The arg is nil in the latter case. 764func (ji *joinIter) prepNonUsing(on spansql.BoolExpr, lhsEC, rhsEC evalContext) { 765 // Having ON or no clause results in the full set of columns from both sides. 766 // Force a copy. 767 ji.ec.cols = append(ji.ec.cols, lhsEC.cols...) 768 ji.ec.cols = append(ji.ec.cols, rhsEC.cols...) 769 ji.ec.row = make(row, len(ji.ec.cols)) 770 771 ji.cond = func(primary, secondary row) (bool, error) { 772 copy(ji.ec.row[ji.primaryOffset:], primary) 773 copy(ji.ec.row[ji.secondaryOffset:], secondary) 774 if on == nil { 775 // No condition; all rows match. 776 return true, nil 777 } 778 b, err := ji.ec.evalBoolExpr(on) 779 if err != nil { 780 return false, err 781 } 782 return b != nil && *b, nil 783 } 784 ji.zero = func(primary row) { 785 for i := range ji.ec.row { 786 ji.ec.row[i] = nil 787 } 788 copy(ji.ec.row[ji.primaryOffset:], primary) 789 } 790} 791 792func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext, flipped bool) error { 793 // Having a USING clause results in the set of named columns once, 794 // followed by the unnamed columns from both sides. 795 796 // lhsUsing is the column indexes in the LHS that the USING clause references. 797 // rhsUsing is similar. 798 // lhsNotUsing/rhsNotUsing are the complement. 799 var lhsUsing, rhsUsing []int 800 var lhsNotUsing, rhsNotUsing []int 801 // lhsUsed, rhsUsed are the set of column indexes in lhsUsing/rhsUsing. 
802 lhsUsed, rhsUsed := make(map[int]bool), make(map[int]bool) 803 for _, id := range using { 804 lhsi, err := lhsEC.resolveColumnIndex(id) 805 if err != nil { 806 return err 807 } 808 lhsUsing = append(lhsUsing, lhsi) 809 lhsUsed[lhsi] = true 810 811 rhsi, err := rhsEC.resolveColumnIndex(id) 812 if err != nil { 813 return err 814 } 815 rhsUsing = append(rhsUsing, rhsi) 816 rhsUsed[rhsi] = true 817 818 // TODO: Should this hide or merge column aliases? 819 ji.ec.cols = append(ji.ec.cols, lhsEC.cols[lhsi]) 820 } 821 for i, col := range lhsEC.cols { 822 if !lhsUsed[i] { 823 ji.ec.cols = append(ji.ec.cols, col) 824 lhsNotUsing = append(lhsNotUsing, i) 825 } 826 } 827 for i, col := range rhsEC.cols { 828 if !rhsUsed[i] { 829 ji.ec.cols = append(ji.ec.cols, col) 830 rhsNotUsing = append(rhsNotUsing, i) 831 } 832 } 833 ji.ec.row = make(row, len(ji.ec.cols)) 834 835 primaryUsing, secondaryUsing := lhsUsing, rhsUsing 836 if flipped { 837 primaryUsing, secondaryUsing = secondaryUsing, primaryUsing 838 } 839 840 orNil := func(r row, i int) interface{} { 841 if r == nil { 842 return nil 843 } 844 return r[i] 845 } 846 // populate writes the data to ji.ec.row in the correct positions. 
847 populate := func(primary, secondary row) { // secondary may be nil 848 j := 0 849 for _, pi := range primaryUsing { 850 ji.ec.row[j] = primary[pi] 851 j++ 852 } 853 lhs, rhs := primary, secondary 854 if flipped { 855 rhs, lhs = lhs, rhs 856 } 857 for _, i := range lhsNotUsing { 858 ji.ec.row[j] = orNil(lhs, i) 859 j++ 860 } 861 for _, i := range rhsNotUsing { 862 ji.ec.row[j] = orNil(rhs, i) 863 j++ 864 } 865 } 866 ji.cond = func(primary, secondary row) (bool, error) { 867 for i, pi := range primaryUsing { 868 si := secondaryUsing[i] 869 if compareVals(primary[pi], secondary[si]) != 0 { 870 return false, nil 871 } 872 } 873 populate(primary, secondary) 874 return true, nil 875 } 876 ji.zero = func(primary row) { 877 populate(primary, nil) 878 } 879 return nil 880} 881 882type joinIter struct { 883 jt spansql.JoinType 884 ec evalContext // combined context 885 886 // The "primary" is scanned (consumed), but the secondary is cloned for each primary row. 887 // Most join types have primary==LHS; a RIGHT JOIN is the exception. 888 primary, secondaryOrig *rawIter 889 890 // The offsets into ec.row that the primary/secondary rows should appear 891 // in the final output. Not used when there's a USING clause. 892 primaryOffset, secondaryOffset int 893 // nullPad is whether primary rows without matching secondary rows 894 // should be yielded with null padding (e.g. OUTER JOINs). 895 nullPad bool 896 897 primaryRow row // current row from primary, or nil if it is time to advance 898 secondary *rawIter // current clone of secondary 899 any bool // true if any secondary rows have matched primaryRow 900 901 // cond reports whether the primary and secondary rows "join" (e.g. the ON clause is true). 902 // It populates ec.row with the output. 903 cond func(primary, secondary row) (bool, error) 904 // zero populates ec.row with the primary row and sets the remainder to NULL. 905 // This is used when nullPad is true and a primary row doesn't match any secondary row. 
906 zero func(primary row) 907} 908 909func (ji *joinIter) Cols() []colInfo { return ji.ec.cols } 910 911func (ji *joinIter) nextPrimary() error { 912 var err error 913 ji.primaryRow, err = ji.primary.Next() 914 if err != nil { 915 return err 916 } 917 ji.secondary = ji.secondaryOrig.clone() 918 ji.any = false 919 return nil 920} 921 922func (ji *joinIter) Next() (row, error) { 923 if ji.primaryRow == nil { 924 if err := ji.nextPrimary(); err != nil { 925 return nil, err 926 } 927 } 928 929 for { 930 secondaryRow, err := ji.secondary.Next() 931 if err == io.EOF { 932 // Finished the current primary row. 933 934 if !ji.any && ji.nullPad { 935 ji.zero(ji.primaryRow) 936 ji.primaryRow = nil 937 return ji.ec.row, nil 938 } 939 940 // Advance to next one. 941 if err := ji.nextPrimary(); err != nil { 942 return nil, err 943 } 944 continue 945 } 946 if err != nil { 947 return nil, err 948 } 949 950 // We have a pair of rows to consider. 951 match, err := ji.cond(ji.primaryRow, secondaryRow) 952 if err != nil { 953 return nil, err 954 } 955 if !match { 956 continue 957 } 958 ji.any = true 959 return ji.ec.row, nil 960 } 961} 962 963func evalSelectOrder(si *selIter, aux []spansql.Expr) (rows []row, keys [][]interface{}, err error) { 964 // This is like toRawIter except it also evaluates the auxiliary expressions for ORDER BY. 965 for { 966 r, err := si.Next() 967 if err == io.EOF { 968 break 969 } else if err != nil { 970 return nil, nil, err 971 } 972 key, err := si.ec.evalExprList(aux) 973 if err != nil { 974 return nil, nil, err 975 } 976 977 rows = append(rows, r.copyAllData()) 978 keys = append(keys, key) 979 } 980 return 981} 982 983// externalRowSorter implements sort.Interface for a slice of rows 984// with an external sort key. 
985type externalRowSorter struct { 986 rows []row 987 keys [][]interface{} 988 desc []bool // may be nil 989} 990 991func (ers externalRowSorter) Len() int { return len(ers.rows) } 992func (ers externalRowSorter) Less(i, j int) bool { 993 return compareValLists(ers.keys[i], ers.keys[j], ers.desc) < 0 994} 995func (ers externalRowSorter) Swap(i, j int) { 996 ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i] 997 ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i] 998} 999