1/*
2Copyright 2020 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spannertest
18
19import (
20	"fmt"
21	"io"
22	"sort"
23
24	"cloud.google.com/go/spanner/spansql"
25)
26
27/*
There are several ways to conceptualise SQL queries. The simplest, and the one
we implement here, is a series of pipelines that transform the data, whether
by pulling from a table (FROM tbl), filtering (WHERE expr), re-ordering (ORDER BY expr)
or applying other transformations.
32
33The order of operations among those supported by Cloud Spanner is
34	FROM + JOIN + set ops [TODO: set ops]
35	WHERE
36	GROUP BY
37	aggregation
38	HAVING [TODO]
39	SELECT
40	DISTINCT
41	ORDER BY
42	OFFSET
43	LIMIT
44*/
45
46// rowIter represents some iteration over rows of data.
47// It is returned by reads and queries.
48type rowIter interface {
49	// Cols returns the metadata about the returned data.
50	Cols() []colInfo
51
52	// Next returns the next row.
53	// If done, it returns (nil, io.EOF).
54	Next() (row, error)
55}
56
// aggSentinel is a synthetic expression that refers to an aggregated value.
// It is transient only; it is never stored and only used during evaluation.
// evalSelect puts one in place of the aggregate function call in the SELECT
// list, so that evaluating the list picks up the precomputed aggregate column.
type aggSentinel struct {
	// Embedding spansql.Expr lets an aggSentinel be placed where a spansql.Expr
	// is expected (e.g. a SELECT list). evalSelect leaves the embedded value nil.
	spansql.Expr
	// Type is the type of the aggregated value.
	Type     spansql.Type
	AggIndex int // Index+1 of SELECT list.
}
64
65// nullIter is a rowIter that returns one empty row only.
66// This is used for queries without a table.
67type nullIter struct {
68	done bool
69}
70
71func (ni *nullIter) Cols() []colInfo { return nil }
72func (ni *nullIter) Next() (row, error) {
73	if ni.done {
74		return nil, io.EOF
75	}
76	ni.done = true
77	return nil, nil
78}
79
80// tableIter is a rowIter that walks a table.
81// It assumes the table is locked for the duration.
82type tableIter struct {
83	t        *table
84	rowIndex int // index of next row to return
85
86	alias spansql.ID // if non-empty, "AS <alias>"
87}
88
89func (ti *tableIter) Cols() []colInfo {
90	// Build colInfo in the original column order.
91	cis := make([]colInfo, len(ti.t.cols))
92	for _, ci := range ti.t.cols {
93		if ti.alias != "" {
94			ci.Alias = spansql.PathExp{ti.alias, ci.Name}
95		}
96		cis[ti.t.origIndex[ci.Name]] = ci
97	}
98	return cis
99}
100
101func (ti *tableIter) Next() (row, error) {
102	if ti.rowIndex >= len(ti.t.rows) {
103		return nil, io.EOF
104	}
105	r := ti.t.rows[ti.rowIndex]
106	ti.rowIndex++
107
108	// Build output row in the original column order.
109	res := make(row, len(r))
110	for i, ci := range ti.t.cols {
111		res[ti.t.origIndex[ci.Name]] = r[i]
112	}
113
114	return res, nil
115}
116
117// rawIter is a rowIter with fixed data.
118type rawIter struct {
119	// cols is the metadata about the returned data.
120	cols []colInfo
121
122	// rows holds the result data itself.
123	rows []row
124}
125
126func (raw *rawIter) Cols() []colInfo { return raw.cols }
127func (raw *rawIter) Next() (row, error) {
128	if len(raw.rows) == 0 {
129		return nil, io.EOF
130	}
131	res := raw.rows[0]
132	raw.rows = raw.rows[1:]
133	return res, nil
134}
135
136func (raw *rawIter) add(src row, colIndexes []int) {
137	raw.rows = append(raw.rows, src.copyData(colIndexes))
138}
139
140// clone makes a shallow copy.
141func (raw *rawIter) clone() *rawIter {
142	return &rawIter{cols: raw.cols, rows: raw.rows}
143}
144
145func toRawIter(ri rowIter) (*rawIter, error) {
146	if raw, ok := ri.(*rawIter); ok {
147		return raw, nil
148	}
149	raw := &rawIter{cols: ri.Cols()}
150	for {
151		row, err := ri.Next()
152		if err == io.EOF {
153			break
154		} else if err != nil {
155			return nil, err
156		}
157		raw.rows = append(raw.rows, row.copyAllData())
158	}
159	return raw, nil
160}
161
162// whereIter applies a WHERE clause.
163type whereIter struct {
164	ri    rowIter
165	ec    evalContext
166	where spansql.BoolExpr
167}
168
169func (wi whereIter) Cols() []colInfo { return wi.ri.Cols() }
170func (wi whereIter) Next() (row, error) {
171	for {
172		row, err := wi.ri.Next()
173		if err != nil {
174			return nil, err
175		}
176		wi.ec.row = row
177
178		b, err := wi.ec.evalBoolExpr(wi.where)
179		if err != nil {
180			return nil, err
181		}
182		if b != nil && *b {
183			return row, nil
184		}
185	}
186}
187
188// selIter applies a SELECT list.
189type selIter struct {
190	ri   rowIter
191	ec   evalContext
192	cis  []colInfo
193	list []spansql.Expr
194
195	distinct bool // whether this is a SELECT DISTINCT
196	seen     []row
197}
198
199func (si *selIter) Cols() []colInfo { return si.cis }
200func (si *selIter) Next() (row, error) {
201	for {
202		r, err := si.next()
203		if err != nil {
204			return nil, err
205		}
206		if si.distinct && !si.keep(r) {
207			continue
208		}
209		return r, nil
210	}
211}
212
213// next retrieves the next row for the SELECT and evaluates its expression list.
214func (si *selIter) next() (row, error) {
215	r, err := si.ri.Next()
216	if err != nil {
217		return nil, err
218	}
219	si.ec.row = r
220
221	var out row
222	for _, e := range si.list {
223		if e == spansql.Star {
224			out = append(out, r...)
225		} else {
226			v, err := si.ec.evalExpr(e)
227			if err != nil {
228				return nil, err
229			}
230			out = append(out, v)
231		}
232	}
233	return out, nil
234}
235
236func (si *selIter) keep(r row) bool {
237	// This is hilariously inefficient; O(N^2) in the number of returned rows.
238	// Some sort of hashing could be done to deduplicate instead.
239	// This also breaks on array/struct types.
240	for _, prev := range si.seen {
241		if rowEqual(prev, r) {
242			return false
243		}
244	}
245	si.seen = append(si.seen, r)
246	return true
247}
248
249// offsetIter applies an OFFSET clause.
250type offsetIter struct {
251	ri   rowIter
252	skip int64
253}
254
255func (oi *offsetIter) Cols() []colInfo { return oi.ri.Cols() }
256func (oi *offsetIter) Next() (row, error) {
257	for oi.skip > 0 {
258		_, err := oi.ri.Next()
259		if err != nil {
260			return nil, err
261		}
262		oi.skip--
263	}
264	row, err := oi.ri.Next()
265	if err != nil {
266		return nil, err
267	}
268	return row, nil
269}
270
271// limitIter applies a LIMIT clause.
272type limitIter struct {
273	ri  rowIter
274	rem int64
275}
276
277func (li *limitIter) Cols() []colInfo { return li.ri.Cols() }
278func (li *limitIter) Next() (row, error) {
279	if li.rem <= 0 {
280		return nil, io.EOF
281	}
282	row, err := li.ri.Next()
283	if err != nil {
284		return nil, err
285	}
286	li.rem--
287	return row, nil
288}
289
// queryParam is a single bound query parameter: its value plus its declared
// Spanner type.
type queryParam struct {
	Value interface{} // internal representation
	Type  spansql.Type
}

// queryParams holds the bound parameters of a query. Keys are presumably the
// parameter names (see the TODO about spansql.Param) — confirm against callers.
type queryParams map[string]queryParam // TODO: change key to spansql.Param?
296
297type queryContext struct {
298	params queryParams
299
300	tables     []*table // sorted by name
301	tableIndex map[spansql.ID]*table
302	locks      int
303}
304
305func (qc *queryContext) Lock() {
306	// Take locks in name order.
307	for _, t := range qc.tables {
308		t.mu.Lock()
309		qc.locks++
310	}
311}
312
313func (qc *queryContext) Unlock() {
314	for _, t := range qc.tables {
315		t.mu.Unlock()
316		qc.locks--
317	}
318}
319
320func (d *database) Query(q spansql.Query, params queryParams) (ri rowIter, err error) {
321	// Figure out the context of the query and take any required locks.
322	qc, err := d.queryContext(q, params)
323	if err != nil {
324		return nil, err
325	}
326	qc.Lock()
327	// On the way out, if there were locks taken, flatten the output
328	// and release the locks.
329	if qc.locks > 0 {
330		defer func() {
331			if err == nil {
332				ri, err = toRawIter(ri)
333			}
334			qc.Unlock()
335		}()
336	}
337
338	// Prepare auxiliary expressions to evaluate for ORDER BY.
339	var aux []spansql.Expr
340	var desc []bool
341	for _, o := range q.Order {
342		aux = append(aux, o.Expr)
343		desc = append(desc, o.Desc)
344	}
345
346	si, err := d.evalSelect(q.Select, qc)
347	if err != nil {
348		return nil, err
349	}
350	ri = si
351
352	// Apply ORDER BY.
353	if len(q.Order) > 0 {
354		// Evaluate the selIter completely, and sort the rows by the auxiliary expressions.
355		rows, keys, err := evalSelectOrder(si, aux)
356		if err != nil {
357			return nil, err
358		}
359		sort.Sort(externalRowSorter{rows: rows, keys: keys, desc: desc})
360		ri = &rawIter{cols: si.cis, rows: rows}
361	}
362
363	// Apply LIMIT, OFFSET.
364	if q.Limit != nil {
365		if q.Offset != nil {
366			off, err := evalLiteralOrParam(q.Offset, params)
367			if err != nil {
368				return nil, err
369			}
370			ri = &offsetIter{ri: ri, skip: off}
371		}
372
373		lim, err := evalLiteralOrParam(q.Limit, params)
374		if err != nil {
375			return nil, err
376		}
377		ri = &limitIter{ri: ri, rem: lim}
378	}
379
380	return ri, nil
381}
382
383func (d *database) queryContext(q spansql.Query, params queryParams) (*queryContext, error) {
384	qc := &queryContext{
385		params: params,
386	}
387
388	// Look for any mentioned tables and add them to qc.tableIndex.
389	addTable := func(name spansql.ID) error {
390		if _, ok := qc.tableIndex[name]; ok {
391			return nil // Already found this table.
392		}
393		t, err := d.table(name)
394		if err != nil {
395			return err
396		}
397		if qc.tableIndex == nil {
398			qc.tableIndex = make(map[spansql.ID]*table)
399		}
400		qc.tableIndex[name] = t
401		return nil
402	}
403	var findTables func(sf spansql.SelectFrom) error
404	findTables = func(sf spansql.SelectFrom) error {
405		switch sf := sf.(type) {
406		default:
407			return fmt.Errorf("can't prepare query context for SelectFrom of type %T", sf)
408		case spansql.SelectFromTable:
409			return addTable(sf.Table)
410		case spansql.SelectFromJoin:
411			if err := findTables(sf.LHS); err != nil {
412				return err
413			}
414			return findTables(sf.RHS)
415		case spansql.SelectFromUnnest:
416			// TODO: if array paths get supported, this will need more work.
417			return nil
418		}
419	}
420	for _, sf := range q.Select.From {
421		if err := findTables(sf); err != nil {
422			return nil, err
423		}
424	}
425
426	// Build qc.tables in name order so we can take locks in a well-defined order.
427	var names []spansql.ID
428	for name := range qc.tableIndex {
429		names = append(names, name)
430	}
431	sort.Slice(names, func(i, j int) bool { return names[i] < names[j] })
432	for _, name := range names {
433		qc.tables = append(qc.tables, qc.tableIndex[name])
434	}
435
436	return qc, nil
437}
438
439func (d *database) evalSelect(sel spansql.Select, qc *queryContext) (si *selIter, evalErr error) {
440	var ri rowIter = &nullIter{}
441	ec := evalContext{
442		params: qc.params,
443	}
444
445	// First stage is to identify the data source.
446	// If there's a FROM then that names a table to use.
447	if len(sel.From) > 1 {
448		return nil, fmt.Errorf("selecting with more than one FROM clause not yet supported")
449	}
450	if len(sel.From) == 1 {
451		var err error
452		ec, ri, err = d.evalSelectFrom(qc, ec, sel.From[0])
453		if err != nil {
454			return nil, err
455		}
456	}
457
458	// Apply WHERE.
459	if sel.Where != nil {
460		ri = whereIter{
461			ri:    ri,
462			ec:    ec,
463			where: sel.Where,
464		}
465	}
466
467	// Load aliases visible to any future iterators,
468	// including GROUP BY and ORDER BY. These are not visible to the WHERE clause.
469	ec.aliases = make(map[spansql.ID]spansql.Expr)
470	for i, alias := range sel.ListAliases {
471		ec.aliases[alias] = sel.List[i]
472	}
473	// TODO: Add aliases for "1", "2", etc.
474
475	// Apply GROUP BY.
476	// This only reorders rows to group rows together;
477	// aggregation happens next.
478	var rowGroups [][2]int // Sequence of half-open intervals of row numbers.
479	if len(sel.GroupBy) > 0 {
480		raw, err := toRawIter(ri)
481		if err != nil {
482			return nil, err
483		}
484		keys := make([][]interface{}, 0, len(raw.rows))
485		for _, row := range raw.rows {
486			// Evaluate sort key for this row.
487			ec.row = row
488			key, err := ec.evalExprList(sel.GroupBy)
489			if err != nil {
490				return nil, err
491			}
492			keys = append(keys, key)
493		}
494
495		// Reorder rows base on the evaluated keys.
496		ers := externalRowSorter{rows: raw.rows, keys: keys}
497		sort.Sort(ers)
498		raw.rows = ers.rows
499
500		// Record groups as a sequence of row intervals.
501		// Each group is a run of the same keys.
502		start := 0
503		for i := 1; i < len(keys); i++ {
504			if compareValLists(keys[i-1], keys[i], nil) == 0 {
505				continue
506			}
507			rowGroups = append(rowGroups, [2]int{start, i})
508			start = i
509		}
510		if len(keys) > 0 {
511			rowGroups = append(rowGroups, [2]int{start, len(keys)})
512		}
513
514		// Clear aliases, since they aren't visible elsewhere.
515		ec.aliases = nil
516
517		ri = raw
518	}
519
520	// Handle aggregation.
521	// TODO: Support more than one aggregation function; does Spanner support that?
522	aggI := -1
523	for i, e := range sel.List {
524		// Supported aggregate funcs have exactly one arg.
525		f, ok := e.(spansql.Func)
526		if !ok || len(f.Args) != 1 {
527			continue
528		}
529		_, ok = aggregateFuncs[f.Name]
530		if !ok {
531			continue
532		}
533		if aggI > -1 {
534			return nil, fmt.Errorf("only one aggregate function is supported")
535		}
536		aggI = i
537	}
538	if aggI > -1 {
539		raw, err := toRawIter(ri)
540		if err != nil {
541			return nil, err
542		}
543		if len(sel.GroupBy) == 0 {
544			// No grouping, so aggregation applies to the entire table (e.g. COUNT(*)).
545			// This may result in a [0,0) entry for empty inputs.
546			rowGroups = [][2]int{{0, len(raw.rows)}}
547		}
548		fexpr := sel.List[aggI].(spansql.Func)
549		fn := aggregateFuncs[fexpr.Name]
550		starArg := fexpr.Args[0] == spansql.Star
551		if starArg && !fn.AcceptStar {
552			return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name)
553		}
554		var argType spansql.Type
555		if !starArg {
556			ci, err := ec.colInfo(fexpr.Args[0])
557			if err != nil {
558				return nil, fmt.Errorf("evaluating aggregate function %s arg type: %v", fexpr.Name, err)
559			}
560			argType = ci.Type
561		}
562
563		// Prepare output.
564		rawOut := &rawIter{
565			// Same as input columns, but also the aggregate value.
566			// Add the colInfo for the aggregate at the end
567			// so we know the type.
568			// Make a copy for safety.
569			cols: append([]colInfo(nil), raw.cols...),
570		}
571
572		var aggType spansql.Type
573		for _, rg := range rowGroups {
574			// Compute aggregate value across this group.
575			var values []interface{}
576			for i := rg[0]; i < rg[1]; i++ {
577				ec.row = raw.rows[i]
578				if starArg {
579					// A non-NULL placeholder is sufficient for aggregation.
580					values = append(values, 1)
581				} else {
582					x, err := ec.evalExpr(fexpr.Args[0])
583					if err != nil {
584						return nil, err
585					}
586					values = append(values, x)
587				}
588			}
589			x, typ, err := fn.Eval(values, argType)
590			if err != nil {
591				return nil, err
592			}
593			aggType = typ
594
595			var outRow row
596			// Output for the row group is the first row of the group (arbitrary,
597			// but it should be representative), and the aggregate value.
598			// TODO: Should this exclude the aggregated expressions so they can't be selected?
599			// If the row group is empty then only the aggregation value is used;
600			// this covers things like COUNT(*) with no matching rows.
601			if rg[0] < len(raw.rows) {
602				repRow := raw.rows[rg[0]]
603				for i := range repRow {
604					outRow = append(outRow, repRow.copyDataElem(i))
605				}
606			} else {
607				// Fill with NULLs to keep the rows and colInfo aligned.
608				for i := 0; i < len(rawOut.cols); i++ {
609					outRow = append(outRow, nil)
610				}
611			}
612			outRow = append(outRow, x)
613			rawOut.rows = append(rawOut.rows, outRow)
614		}
615
616		if aggType == (spansql.Type{}) {
617			// Fallback; there might not be any groups.
618			// TODO: Should this be in aggregateFunc?
619			aggType = int64Type
620		}
621		rawOut.cols = append(raw.cols, colInfo{
622			Name:     spansql.ID(fexpr.SQL()), // TODO: this is a bit hokey, but it is output only
623			Type:     aggType,
624			AggIndex: aggI + 1,
625		})
626
627		ri = rawOut
628		ec.cols = rawOut.cols
629		sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value.
630			Type:     aggType,
631			AggIndex: aggI + 1,
632		}
633	}
634
635	// TODO: Support table sampling.
636
637	// Apply SELECT list.
638	var colInfos []colInfo
639	for i, e := range sel.List {
640		if e == spansql.Star {
641			colInfos = append(colInfos, ec.cols...)
642		} else {
643			ci, err := ec.colInfo(e)
644			if err != nil {
645				return nil, err
646			}
647			if len(sel.ListAliases) > 0 {
648				alias := sel.ListAliases[i]
649				if alias != "" {
650					ci.Name = alias
651				}
652			}
653			// TODO: deal with ci.Name == ""?
654			colInfos = append(colInfos, ci)
655		}
656	}
657
658	return &selIter{
659		ri:   ri,
660		ec:   ec,
661		cis:  colInfos,
662		list: sel.List,
663
664		distinct: sel.Distinct, // Apply DISTINCT.
665	}, nil
666}
667
668func (d *database) evalSelectFrom(qc *queryContext, ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, error) {
669	switch sf := sf.(type) {
670	default:
671		return ec, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf)
672	case spansql.SelectFromTable:
673		t, ok := qc.tableIndex[sf.Table]
674		if !ok {
675			// This shouldn't be possible; the queryContext should have discovered missing tables already.
676			return ec, nil, fmt.Errorf("unknown table %q", sf.Table)
677		}
678		ti := &tableIter{t: t}
679		if sf.Alias != "" {
680			ti.alias = sf.Alias
681		} else {
682			// There is an implicit alias using the table name.
683			// https://cloud.google.com/spanner/docs/query-syntax#implicit_aliases
684			ti.alias = sf.Table
685		}
686		ec.cols = ti.Cols()
687		return ec, ti, nil
688	case spansql.SelectFromJoin:
689		// TODO: Avoid the toRawIter calls here by doing the RHS recursive evalSelectFrom in joinIter.Next on demand.
690
691		lhsEC, lhs, err := d.evalSelectFrom(qc, ec, sf.LHS)
692		if err != nil {
693			return ec, nil, err
694		}
695		lhsRaw, err := toRawIter(lhs)
696		if err != nil {
697			return ec, nil, err
698		}
699
700		rhsEC, rhs, err := d.evalSelectFrom(qc, ec, sf.RHS)
701		if err != nil {
702			return ec, nil, err
703		}
704		rhsRaw, err := toRawIter(rhs)
705		if err != nil {
706			return ec, nil, err
707		}
708
709		ji, ec, err := newJoinIter(lhsRaw, rhsRaw, lhsEC, rhsEC, sf)
710		if err != nil {
711			return ec, nil, err
712		}
713		return ec, ji, nil
714	case spansql.SelectFromUnnest:
715		// TODO: Do all relevant types flow through here? Path expressions might be tricky here.
716		col, err := ec.colInfo(sf.Expr)
717		if err != nil {
718			return ec, nil, fmt.Errorf("evaluating type of UNNEST arg: %v", err)
719		}
720		if !col.Type.Array {
721			return ec, nil, fmt.Errorf("type of UNNEST arg is non-array %s", col.Type.SQL())
722		}
723		// The output of this UNNEST is the non-array version.
724		col.Name = sf.Alias // may be empty
725		col.Type.Array = false
726
727		// Evaluate the expression, and yield a virtual table with one column.
728		e, err := ec.evalExpr(sf.Expr)
729		if err != nil {
730			return ec, nil, fmt.Errorf("evaluating UNNEST arg: %v", err)
731		}
732		arr, ok := e.([]interface{})
733		if !ok {
734			return ec, nil, fmt.Errorf("evaluating UNNEST arg gave %t, want array", e)
735		}
736		var rows []row
737		for _, v := range arr {
738			rows = append(rows, row{v})
739		}
740
741		ri := &rawIter{
742			cols: []colInfo{col},
743			rows: rows,
744		}
745		ec.cols = ri.cols
746		return ec, ri, nil
747	}
748}
749
750func newJoinIter(lhs, rhs *rawIter, lhsEC, rhsEC evalContext, sfj spansql.SelectFromJoin) (*joinIter, evalContext, error) {
751	if sfj.On != nil && len(sfj.Using) > 0 {
752		return nil, evalContext{}, fmt.Errorf("JOIN may not have both ON and USING clauses")
753	}
754	if sfj.On == nil && len(sfj.Using) == 0 && sfj.Type != spansql.CrossJoin {
755		// TODO: This isn't correct for joining against a non-table.
756		return nil, evalContext{}, fmt.Errorf("non-CROSS JOIN must have ON or USING clause")
757	}
758
759	// Start with the context from the LHS (aliases and params should be the same on both sides).
760	ji := &joinIter{
761		jt: sfj.Type,
762		ec: lhsEC,
763
764		primary:       lhs,
765		secondaryOrig: rhs,
766
767		primaryOffset:   0,
768		secondaryOffset: len(lhsEC.cols),
769	}
770	switch ji.jt {
771	case spansql.LeftJoin:
772		ji.nullPad = true
773	case spansql.RightJoin:
774		ji.nullPad = true
775		// Primary is RHS.
776		ji.ec = rhsEC
777		ji.primary, ji.secondaryOrig = rhs, lhs
778		ji.primaryOffset, ji.secondaryOffset = len(rhsEC.cols), 0
779	case spansql.FullJoin:
780		// FULL JOIN is implemented as a LEFT JOIN with tracking for which rows of the RHS
781		// have been used. Then, at the end of the iteration, the unused RHS rows are emitted.
782		ji.nullPad = true
783		ji.used = make([]bool, 0, 10) // arbitrary preallocation
784	}
785	ji.ec.cols, ji.ec.row = nil, nil
786
787	// Construct a merged evalContext, and prepare the join condition evaluation.
788	// TODO: Remove ambiguous names here? Or catch them when evaluated?
789	// TODO: aliases might need work?
790	if len(sfj.Using) == 0 {
791		ji.prepNonUsing(sfj.On, lhsEC, rhsEC)
792	} else {
793		if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC, ji.jt == spansql.RightJoin); err != nil {
794			return nil, evalContext{}, err
795		}
796	}
797
798	return ji, ji.ec, nil
799}
800
801// prepNonUsing configures the joinIter to evaluate with an ON clause or no join clause.
802// The arg is nil in the latter case.
803func (ji *joinIter) prepNonUsing(on spansql.BoolExpr, lhsEC, rhsEC evalContext) {
804	// Having ON or no clause results in the full set of columns from both sides.
805	// Force a copy.
806	ji.ec.cols = append(ji.ec.cols, lhsEC.cols...)
807	ji.ec.cols = append(ji.ec.cols, rhsEC.cols...)
808	ji.ec.row = make(row, len(ji.ec.cols))
809
810	ji.cond = func(primary, secondary row) (bool, error) {
811		copy(ji.ec.row[ji.primaryOffset:], primary)
812		copy(ji.ec.row[ji.secondaryOffset:], secondary)
813		if on == nil {
814			// No condition; all rows match.
815			return true, nil
816		}
817		b, err := ji.ec.evalBoolExpr(on)
818		if err != nil {
819			return false, err
820		}
821		return b != nil && *b, nil
822	}
823	ji.zero = func(primary, secondary row) {
824		for i := range ji.ec.row {
825			ji.ec.row[i] = nil
826		}
827		copy(ji.ec.row[ji.primaryOffset:], primary)
828		copy(ji.ec.row[ji.secondaryOffset:], secondary)
829	}
830}
831
// prepUsing configures the joinIter for a JOIN ... USING (cols) clause.
//
// The output columns are each USING column once (taking its colInfo from the
// LHS side), then the remaining LHS columns, then the remaining RHS columns.
// flipped is true when the primary input is the RHS (newJoinIter passes
// ji.jt == spansql.RightJoin), so primary/secondary must be mapped back to
// LHS/RHS when filling the output row.
//
// It returns an error if a USING column cannot be resolved on either side.
func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext, flipped bool) error {
	// Having a USING clause results in the set of named columns once,
	// followed by the unnamed columns from both sides.

	// lhsUsing is the column indexes in the LHS that the USING clause references.
	// rhsUsing is similar.
	// lhsNotUsing/rhsNotUsing are the complement.
	var lhsUsing, rhsUsing []int
	var lhsNotUsing, rhsNotUsing []int
	// lhsUsed, rhsUsed are the set of column indexes in lhsUsing/rhsUsing.
	lhsUsed, rhsUsed := make(map[int]bool), make(map[int]bool)
	for _, id := range using {
		lhsi, err := lhsEC.resolveColumnIndex(id)
		if err != nil {
			return err
		}
		lhsUsing = append(lhsUsing, lhsi)
		lhsUsed[lhsi] = true

		rhsi, err := rhsEC.resolveColumnIndex(id)
		if err != nil {
			return err
		}
		rhsUsing = append(rhsUsing, rhsi)
		rhsUsed[rhsi] = true

		// TODO: Should this hide or merge column aliases?
		ji.ec.cols = append(ji.ec.cols, lhsEC.cols[lhsi])
	}
	for i, col := range lhsEC.cols {
		if !lhsUsed[i] {
			ji.ec.cols = append(ji.ec.cols, col)
			lhsNotUsing = append(lhsNotUsing, i)
		}
	}
	for i, col := range rhsEC.cols {
		if !rhsUsed[i] {
			ji.ec.cols = append(ji.ec.cols, col)
			rhsNotUsing = append(rhsNotUsing, i)
		}
	}
	// Output buffer reused for every row this joinIter yields.
	ji.ec.row = make(row, len(ji.ec.cols))

	// Map the USING column indexes onto the primary/secondary inputs.
	primaryUsing, secondaryUsing := lhsUsing, rhsUsing
	if flipped {
		primaryUsing, secondaryUsing = secondaryUsing, primaryUsing
	}

	// orNil reads r[i], treating a missing (nil) row as all NULLs.
	orNil := func(r row, i int) interface{} {
		if r == nil {
			return nil
		}
		return r[i]
	}
	// populate writes the data to ji.ec.row in the correct positions.
	// The USING values come from the primary row when present, otherwise from
	// the secondary row (the unmatched-secondary case of a FULL JOIN).
	populate := func(primary, secondary row) { // either may be nil
		j := 0
		if primary != nil {
			for _, pi := range primaryUsing {
				ji.ec.row[j] = primary[pi]
				j++
			}
		} else {
			for _, si := range secondaryUsing {
				ji.ec.row[j] = secondary[si]
				j++
			}
		}
		// Translate primary/secondary back to LHS/RHS for the non-USING columns.
		lhs, rhs := primary, secondary
		if flipped {
			rhs, lhs = lhs, rhs
		}
		for _, i := range lhsNotUsing {
			ji.ec.row[j] = orNil(lhs, i)
			j++
		}
		for _, i := range rhsNotUsing {
			ji.ec.row[j] = orNil(rhs, i)
			j++
		}
	}
	// cond matches when every USING column compares equal on both sides.
	// NOTE(review): if compareVals treats two NULLs as equal, NULL keys would
	// match here, unlike SQL equality — confirm compareVals' NULL handling.
	ji.cond = func(primary, secondary row) (bool, error) {
		for i, pi := range primaryUsing {
			si := secondaryUsing[i]
			if compareVals(primary[pi], secondary[si]) != 0 {
				return false, nil
			}
		}
		populate(primary, secondary)
		return true, nil
	}
	ji.zero = func(primary, secondary row) {
		populate(primary, secondary)
	}
	return nil
}
928
// joinIter is a rowIter that joins two materialized inputs with a nested-loop
// scan: for each primary row, the whole secondary input is rescanned.
// Rows it yields are its reused ec.row buffer.
type joinIter struct {
	jt spansql.JoinType
	ec evalContext // combined context

	// The "primary" is scanned (consumed), but the secondary is cloned for each primary row.
	// Most join types have primary==LHS; a RIGHT JOIN is the exception.
	primary, secondaryOrig *rawIter

	// The offsets into ec.row that the primary/secondary rows should appear
	// in the final output. Not used when there's a USING clause.
	primaryOffset, secondaryOffset int
	// nullPad is whether primary rows without matching secondary rows
	// should be yielded with null padding (e.g. OUTER JOINs).
	nullPad bool

	primaryRow    row      // current row from primary, or nil if it is time to advance
	secondary     *rawIter // current clone of secondary
	secondaryRead int      // number of rows already read from secondary
	any           bool     // true if any secondary rows have matched primaryRow

	// cond reports whether the primary and secondary rows "join" (e.g. the ON clause is true).
	// It populates ec.row with the output.
	cond func(primary, secondary row) (bool, error)
	// zero populates ec.row with the primary or secondary row data (either of which may be nil),
	// and sets the remainder to NULL.
	// This is used when nullPad is true and a primary or secondary row doesn't match.
	zero func(primary, secondary row)

	// For FULL JOIN, this tracks the secondary rows that have been used
	// (used[i] is true once secondary row i has matched some primary row).
	// It is non-nil only for FULL JOIN.
	used        []bool
	zeroUnused  bool // set when emitting unused secondary rows
	unusedIndex int  // next index of used to check
}
963
964func (ji *joinIter) Cols() []colInfo { return ji.ec.cols }
965
966func (ji *joinIter) nextPrimary() error {
967	var err error
968	ji.primaryRow, err = ji.primary.Next()
969	if err != nil {
970		return err
971	}
972	ji.secondary = ji.secondaryOrig.clone()
973	ji.secondaryRead = 0
974	ji.any = false
975	return nil
976}
977
978func (ji *joinIter) Next() (row, error) {
979	if ji.primaryRow == nil && !ji.zeroUnused {
980		err := ji.nextPrimary()
981		if err == io.EOF && ji.used != nil {
982			// Drop down to emitting unused secondary rows.
983			ji.zeroUnused = true
984			ji.secondary = nil
985			goto scanJiUsed
986		}
987		if err != nil {
988			return nil, err
989		}
990	}
991scanJiUsed:
992	if ji.zeroUnused {
993		if ji.secondary == nil {
994			ji.secondary = ji.secondaryOrig.clone()
995			ji.secondaryRead = 0
996		}
997		for ji.unusedIndex < len(ji.used) && ji.used[ji.unusedIndex] {
998			ji.unusedIndex++
999		}
1000		if ji.unusedIndex >= len(ji.used) || ji.secondaryRead >= len(ji.used) {
1001			// Truly finished.
1002			return nil, io.EOF
1003		}
1004		var secondaryRow row
1005		for ji.secondaryRead <= ji.unusedIndex {
1006			var err error
1007			secondaryRow, err = ji.secondary.Next()
1008			if err != nil {
1009				return nil, err
1010			}
1011			ji.secondaryRead++
1012		}
1013		ji.zero(nil, secondaryRow)
1014		return ji.ec.row, nil
1015	}
1016
1017	for {
1018		secondaryRow, err := ji.secondary.Next()
1019		if err == io.EOF {
1020			// Finished the current primary row.
1021
1022			if !ji.any && ji.nullPad {
1023				ji.zero(ji.primaryRow, nil)
1024				ji.primaryRow = nil
1025				return ji.ec.row, nil
1026			}
1027
1028			// Advance to next one.
1029			err := ji.nextPrimary()
1030			if err == io.EOF && ji.used != nil {
1031				ji.zeroUnused = true
1032				ji.secondary = nil
1033				goto scanJiUsed
1034			}
1035			if err != nil {
1036				return nil, err
1037			}
1038			continue
1039		}
1040		if err != nil {
1041			return nil, err
1042		}
1043		ji.secondaryRead++
1044		if ji.used != nil {
1045			for len(ji.used) < ji.secondaryRead {
1046				ji.used = append(ji.used, false)
1047			}
1048		}
1049
1050		// We have a pair of rows to consider.
1051		match, err := ji.cond(ji.primaryRow, secondaryRow)
1052		if err != nil {
1053			return nil, err
1054		}
1055		if !match {
1056			continue
1057		}
1058		ji.any = true
1059		if ji.used != nil {
1060			// Make a note that we used this secondary row.
1061			ji.used[ji.secondaryRead-1] = true
1062		}
1063		return ji.ec.row, nil
1064	}
1065}
1066
1067func evalSelectOrder(si *selIter, aux []spansql.Expr) (rows []row, keys [][]interface{}, err error) {
1068	// This is like toRawIter except it also evaluates the auxiliary expressions for ORDER BY.
1069	for {
1070		r, err := si.Next()
1071		if err == io.EOF {
1072			break
1073		} else if err != nil {
1074			return nil, nil, err
1075		}
1076		key, err := si.ec.evalExprList(aux)
1077		if err != nil {
1078			return nil, nil, err
1079		}
1080
1081		rows = append(rows, r.copyAllData())
1082		keys = append(keys, key)
1083	}
1084	return
1085}
1086
1087// externalRowSorter implements sort.Interface for a slice of rows
1088// with an external sort key.
1089type externalRowSorter struct {
1090	rows []row
1091	keys [][]interface{}
1092	desc []bool // may be nil
1093}
1094
1095func (ers externalRowSorter) Len() int { return len(ers.rows) }
1096func (ers externalRowSorter) Less(i, j int) bool {
1097	return compareValLists(ers.keys[i], ers.keys[j], ers.desc) < 0
1098}
1099func (ers externalRowSorter) Swap(i, j int) {
1100	ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i]
1101	ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i]
1102}
1103