1/*
2Copyright 2020 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spannertest
18
19import (
20	"fmt"
21	"io"
22	"sort"
23
24	"cloud.google.com/go/spanner/spansql"
25)
26
27/*
There are several ways to conceptualise SQL queries. The simplest, and what
29we implement here, is a series of pipelines that transform the data, whether
30pulling from a table (FROM tbl), filtering (WHERE expr), re-ordering (ORDER BY expr)
31or other transformations.
32
33The order of operations among those supported by Cloud Spanner is
34	FROM + JOIN + set ops [TODO: set ops]
35	WHERE
36	GROUP BY
37	aggregation
38	HAVING [TODO]
39	SELECT
40	DISTINCT
41	ORDER BY
42	OFFSET
43	LIMIT
44*/
45
// rowIter represents some iteration over rows of data.
// It is returned by reads and queries.
//
// The iterators in this file are consumed by calling Next repeatedly until
// it reports io.EOF; several of them wrap another rowIter to form a pipeline.
type rowIter interface {
	// Cols returns the metadata about the returned data.
	Cols() []colInfo

	// Next returns the next row.
	// If done, it returns (nil, io.EOF).
	// Any other non-nil error means the iteration failed.
	Next() (row, error)
}
56
// aggSentinel is a synthetic expression that refers to an aggregated value.
// It is transient only; it is never stored and only used during evaluation.
//
// evalSelect substitutes one of these for each aggregate function call in the
// SELECT list once the aggregate values have been computed.
type aggSentinel struct {
	// Embedded so that aggSentinel can stand in for a spansql.Expr;
	// evalSelect leaves it unset (nil).
	spansql.Expr
	// Type is the type of the aggregated value.
	Type     spansql.Type
	AggIndex int // Index+1 of SELECT list.
}
64
65// nullIter is a rowIter that returns one empty row only.
66// This is used for queries without a table.
67type nullIter struct {
68	done bool
69}
70
71func (ni *nullIter) Cols() []colInfo { return nil }
72func (ni *nullIter) Next() (row, error) {
73	if ni.done {
74		return nil, io.EOF
75	}
76	ni.done = true
77	return nil, nil
78}
79
80// tableIter is a rowIter that walks a table.
81// It assumes the table is locked for the duration.
82type tableIter struct {
83	t        *table
84	rowIndex int // index of next row to return
85
86	alias spansql.ID // if non-empty, "AS <alias>"
87}
88
89func (ti *tableIter) Cols() []colInfo {
90	// Build colInfo in the original column order.
91	cis := make([]colInfo, len(ti.t.cols))
92	for _, ci := range ti.t.cols {
93		if ti.alias != "" {
94			ci.Alias = spansql.PathExp{ti.alias, ci.Name}
95		}
96		cis[ti.t.origIndex[ci.Name]] = ci
97	}
98	return cis
99}
100
101func (ti *tableIter) Next() (row, error) {
102	if ti.rowIndex >= len(ti.t.rows) {
103		return nil, io.EOF
104	}
105	r := ti.t.rows[ti.rowIndex]
106	ti.rowIndex++
107
108	// Build output row in the original column order.
109	res := make(row, len(r))
110	for i, ci := range ti.t.cols {
111		res[ti.t.origIndex[ci.Name]] = r[i]
112	}
113
114	return res, nil
115}
116
117// rawIter is a rowIter with fixed data.
118type rawIter struct {
119	// cols is the metadata about the returned data.
120	cols []colInfo
121
122	// rows holds the result data itself.
123	rows []row
124}
125
126func (raw *rawIter) Cols() []colInfo { return raw.cols }
127func (raw *rawIter) Next() (row, error) {
128	if len(raw.rows) == 0 {
129		return nil, io.EOF
130	}
131	res := raw.rows[0]
132	raw.rows = raw.rows[1:]
133	return res, nil
134}
135
136func (raw *rawIter) add(src row, colIndexes []int) {
137	raw.rows = append(raw.rows, src.copyData(colIndexes))
138}
139
140// clone makes a shallow copy.
141func (raw *rawIter) clone() *rawIter {
142	return &rawIter{cols: raw.cols, rows: raw.rows}
143}
144
145func toRawIter(ri rowIter) (*rawIter, error) {
146	if raw, ok := ri.(*rawIter); ok {
147		return raw, nil
148	}
149	raw := &rawIter{cols: ri.Cols()}
150	for {
151		row, err := ri.Next()
152		if err == io.EOF {
153			break
154		} else if err != nil {
155			return nil, err
156		}
157		raw.rows = append(raw.rows, row.copyAllData())
158	}
159	return raw, nil
160}
161
162// whereIter applies a WHERE clause.
163type whereIter struct {
164	ri    rowIter
165	ec    evalContext
166	where spansql.BoolExpr
167}
168
169func (wi whereIter) Cols() []colInfo { return wi.ri.Cols() }
170func (wi whereIter) Next() (row, error) {
171	for {
172		row, err := wi.ri.Next()
173		if err != nil {
174			return nil, err
175		}
176		wi.ec.row = row
177
178		b, err := wi.ec.evalBoolExpr(wi.where)
179		if err != nil {
180			return nil, err
181		}
182		if b != nil && *b {
183			return row, nil
184		}
185	}
186}
187
188// selIter applies a SELECT list.
189type selIter struct {
190	ri   rowIter
191	ec   evalContext
192	cis  []colInfo
193	list []spansql.Expr
194
195	distinct bool // whether this is a SELECT DISTINCT
196	seen     []row
197}
198
199func (si *selIter) Cols() []colInfo { return si.cis }
200func (si *selIter) Next() (row, error) {
201	for {
202		r, err := si.next()
203		if err != nil {
204			return nil, err
205		}
206		if si.distinct && !si.keep(r) {
207			continue
208		}
209		return r, nil
210	}
211}
212
213// next retrieves the next row for the SELECT and evaluates its expression list.
214func (si *selIter) next() (row, error) {
215	r, err := si.ri.Next()
216	if err != nil {
217		return nil, err
218	}
219	si.ec.row = r
220
221	var out row
222	for _, e := range si.list {
223		if e == spansql.Star {
224			out = append(out, r...)
225		} else {
226			v, err := si.ec.evalExpr(e)
227			if err != nil {
228				return nil, err
229			}
230			out = append(out, v)
231		}
232	}
233	return out, nil
234}
235
236func (si *selIter) keep(r row) bool {
237	// This is hilariously inefficient; O(N^2) in the number of returned rows.
238	// Some sort of hashing could be done to deduplicate instead.
239	// This also breaks on array/struct types.
240	for _, prev := range si.seen {
241		if rowEqual(prev, r) {
242			return false
243		}
244	}
245	si.seen = append(si.seen, r)
246	return true
247}
248
249// offsetIter applies an OFFSET clause.
250type offsetIter struct {
251	ri   rowIter
252	skip int64
253}
254
255func (oi *offsetIter) Cols() []colInfo { return oi.ri.Cols() }
256func (oi *offsetIter) Next() (row, error) {
257	for oi.skip > 0 {
258		_, err := oi.ri.Next()
259		if err != nil {
260			return nil, err
261		}
262		oi.skip--
263	}
264	row, err := oi.ri.Next()
265	if err != nil {
266		return nil, err
267	}
268	return row, nil
269}
270
271// limitIter applies a LIMIT clause.
272type limitIter struct {
273	ri  rowIter
274	rem int64
275}
276
277func (li *limitIter) Cols() []colInfo { return li.ri.Cols() }
278func (li *limitIter) Next() (row, error) {
279	if li.rem <= 0 {
280		return nil, io.EOF
281	}
282	row, err := li.ri.Next()
283	if err != nil {
284		return nil, err
285	}
286	li.rem--
287	return row, nil
288}
289
// queryParam is the value and type of a single query parameter.
type queryParam struct {
	Value interface{} // internal representation
	// Type is the SQL type that accompanies Value.
	Type  spansql.Type
}

// queryParams holds the parameters supplied with a query, keyed by a string.
// NOTE(review): the key is presumably the parameter name — confirm against callers.
type queryParams map[string]queryParam // TODO: change key to spansql.Param?
296
297type queryContext struct {
298	params queryParams
299
300	tables     []*table // sorted by name
301	tableIndex map[spansql.ID]*table
302	locks      int
303}
304
305func (qc *queryContext) Lock() {
306	// Take locks in name order.
307	for _, t := range qc.tables {
308		t.mu.Lock()
309		qc.locks++
310	}
311}
312
313func (qc *queryContext) Unlock() {
314	for _, t := range qc.tables {
315		t.mu.Unlock()
316		qc.locks--
317	}
318}
319
320func (d *database) Query(q spansql.Query, params queryParams) (ri rowIter, err error) {
321	// Figure out the context of the query and take any required locks.
322	qc, err := d.queryContext(q, params)
323	if err != nil {
324		return nil, err
325	}
326	qc.Lock()
327	// On the way out, if there were locks taken, flatten the output
328	// and release the locks.
329	if qc.locks > 0 {
330		defer func() {
331			if err == nil {
332				ri, err = toRawIter(ri)
333			}
334			qc.Unlock()
335		}()
336	}
337
338	// Prepare auxiliary expressions to evaluate for ORDER BY.
339	var aux []spansql.Expr
340	var desc []bool
341	for _, o := range q.Order {
342		aux = append(aux, o.Expr)
343		desc = append(desc, o.Desc)
344	}
345
346	si, err := d.evalSelect(q.Select, qc)
347	if err != nil {
348		return nil, err
349	}
350	ri = si
351
352	// Apply ORDER BY.
353	if len(q.Order) > 0 {
354		// Evaluate the selIter completely, and sort the rows by the auxiliary expressions.
355		rows, keys, err := evalSelectOrder(si, aux)
356		if err != nil {
357			return nil, err
358		}
359		sort.Sort(externalRowSorter{rows: rows, keys: keys, desc: desc})
360		ri = &rawIter{cols: si.cis, rows: rows}
361	}
362
363	// Apply LIMIT, OFFSET.
364	if q.Limit != nil {
365		if q.Offset != nil {
366			off, err := evalLiteralOrParam(q.Offset, params)
367			if err != nil {
368				return nil, err
369			}
370			ri = &offsetIter{ri: ri, skip: off}
371		}
372
373		lim, err := evalLiteralOrParam(q.Limit, params)
374		if err != nil {
375			return nil, err
376		}
377		ri = &limitIter{ri: ri, rem: lim}
378	}
379
380	return ri, nil
381}
382
383func (d *database) queryContext(q spansql.Query, params queryParams) (*queryContext, error) {
384	qc := &queryContext{
385		params: params,
386	}
387
388	// Look for any mentioned tables and add them to qc.tableIndex.
389	addTable := func(name spansql.ID) error {
390		if _, ok := qc.tableIndex[name]; ok {
391			return nil // Already found this table.
392		}
393		t, err := d.table(name)
394		if err != nil {
395			return err
396		}
397		if qc.tableIndex == nil {
398			qc.tableIndex = make(map[spansql.ID]*table)
399		}
400		qc.tableIndex[name] = t
401		return nil
402	}
403	var findTables func(sf spansql.SelectFrom) error
404	findTables = func(sf spansql.SelectFrom) error {
405		switch sf := sf.(type) {
406		default:
407			return fmt.Errorf("can't prepare query context for SelectFrom of type %T", sf)
408		case spansql.SelectFromTable:
409			return addTable(sf.Table)
410		case spansql.SelectFromJoin:
411			if err := findTables(sf.LHS); err != nil {
412				return err
413			}
414			return findTables(sf.RHS)
415		case spansql.SelectFromUnnest:
416			// TODO: if array paths get supported, this will need more work.
417			return nil
418		}
419	}
420	for _, sf := range q.Select.From {
421		if err := findTables(sf); err != nil {
422			return nil, err
423		}
424	}
425
426	// Build qc.tables in name order so we can take locks in a well-defined order.
427	var names []spansql.ID
428	for name := range qc.tableIndex {
429		names = append(names, name)
430	}
431	sort.Slice(names, func(i, j int) bool { return names[i] < names[j] })
432	for _, name := range names {
433		qc.tables = append(qc.tables, qc.tableIndex[name])
434	}
435
436	return qc, nil
437}
438
439func (d *database) evalSelect(sel spansql.Select, qc *queryContext) (si *selIter, evalErr error) {
440	var ri rowIter = &nullIter{}
441	ec := evalContext{
442		params: qc.params,
443	}
444
445	// First stage is to identify the data source.
446	// If there's a FROM then that names a table to use.
447	if len(sel.From) > 1 {
448		return nil, fmt.Errorf("selecting with more than one FROM clause not yet supported")
449	}
450	if len(sel.From) == 1 {
451		var err error
452		ec, ri, err = d.evalSelectFrom(qc, ec, sel.From[0])
453		if err != nil {
454			return nil, err
455		}
456	}
457
458	// Apply WHERE.
459	if sel.Where != nil {
460		ri = whereIter{
461			ri:    ri,
462			ec:    ec,
463			where: sel.Where,
464		}
465	}
466
467	// Load aliases visible to any future iterators,
468	// including GROUP BY and ORDER BY. These are not visible to the WHERE clause.
469	ec.aliases = make(map[spansql.ID]spansql.Expr)
470	for i, alias := range sel.ListAliases {
471		ec.aliases[alias] = sel.List[i]
472	}
473	// TODO: Add aliases for "1", "2", etc.
474
475	// Apply GROUP BY.
476	// This only reorders rows to group rows together;
477	// aggregation happens next.
478	var rowGroups [][2]int // Sequence of half-open intervals of row numbers.
479	if len(sel.GroupBy) > 0 {
480		raw, err := toRawIter(ri)
481		if err != nil {
482			return nil, err
483		}
484		keys := make([][]interface{}, 0, len(raw.rows))
485		for _, row := range raw.rows {
486			// Evaluate sort key for this row.
487			ec.row = row
488			key, err := ec.evalExprList(sel.GroupBy)
489			if err != nil {
490				return nil, err
491			}
492			keys = append(keys, key)
493		}
494
495		// Reorder rows base on the evaluated keys.
496		ers := externalRowSorter{rows: raw.rows, keys: keys}
497		sort.Sort(ers)
498		raw.rows = ers.rows
499
500		// Record groups as a sequence of row intervals.
501		// Each group is a run of the same keys.
502		start := 0
503		for i := 1; i < len(keys); i++ {
504			if compareValLists(keys[i-1], keys[i], nil) == 0 {
505				continue
506			}
507			rowGroups = append(rowGroups, [2]int{start, i})
508			start = i
509		}
510		if len(keys) > 0 {
511			rowGroups = append(rowGroups, [2]int{start, len(keys)})
512		}
513
514		// Clear aliases, since they aren't visible elsewhere.
515		ec.aliases = nil
516
517		ri = raw
518	}
519
520	// Handle aggregation.
521	var aggI []int
522	for i, e := range sel.List {
523		// Supported aggregate funcs have exactly one arg.
524		f, ok := e.(spansql.Func)
525		if !ok || len(f.Args) != 1 {
526			continue
527		}
528		_, ok = aggregateFuncs[f.Name]
529		if !ok {
530			continue
531		}
532		aggI = append(aggI, i)
533	}
534	if len(aggI) > 0 {
535		raw, err := toRawIter(ri)
536		if err != nil {
537			return nil, err
538		}
539		if len(sel.GroupBy) == 0 {
540			// No grouping, so aggregation applies to the entire table (e.g. COUNT(*)).
541			// This may result in a [0,0) entry for empty inputs.
542			rowGroups = [][2]int{{0, len(raw.rows)}}
543		}
544
545		// Prepare output.
546		rawOut := &rawIter{
547			// Same as input columns, but also the aggregate value.
548			// Add the colInfo for the aggregate at the end
549			// so we know the type.
550			// Make a copy for safety.
551			cols: append([]colInfo(nil), raw.cols...),
552		}
553
554		aggType := make([]*spansql.Type, len(aggI))
555		for _, rg := range rowGroups {
556			var outRow row
557			// Output for the row group is the first row of the group (arbitrary,
558			// but it should be representative), and the aggregate value.
559			// TODO: Should this exclude the aggregated expressions so they can't be selected?
560			// If the row group is empty then only the aggregation value is used;
561			// this covers things like COUNT(*) with no matching rows.
562			if rg[0] < len(raw.rows) {
563				repRow := raw.rows[rg[0]]
564				for i := range repRow {
565					outRow = append(outRow, repRow.copyDataElem(i))
566				}
567			} else {
568				// Fill with NULLs to keep the rows and colInfo aligned.
569				for i := 0; i < len(rawOut.cols); i++ {
570					outRow = append(outRow, nil)
571				}
572			}
573
574			for j, aggI := range aggI {
575				fexpr := sel.List[aggI].(spansql.Func)
576				fn := aggregateFuncs[fexpr.Name]
577				starArg := fexpr.Args[0] == spansql.Star
578				if starArg && !fn.AcceptStar {
579					return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name)
580				}
581				var argType spansql.Type
582				if !starArg {
583					ci, err := ec.colInfo(fexpr.Args[0])
584					if err != nil {
585						return nil, fmt.Errorf("evaluating aggregate function %s arg type: %v", fexpr.Name, err)
586					}
587					argType = ci.Type
588				}
589
590				// Compute aggregate value across this group.
591				var values []interface{}
592				for i := rg[0]; i < rg[1]; i++ {
593					ec.row = raw.rows[i]
594					if starArg {
595						// A non-NULL placeholder is sufficient for aggregation.
596						values = append(values, 1)
597					} else {
598						x, err := ec.evalExpr(fexpr.Args[0])
599						if err != nil {
600							return nil, err
601						}
602						values = append(values, x)
603					}
604				}
605				x, typ, err := fn.Eval(values, argType)
606				if err != nil {
607					return nil, err
608				}
609				aggType[j] = &typ
610
611				outRow = append(outRow, x)
612			}
613			rawOut.rows = append(rawOut.rows, outRow)
614		}
615
616		for j, aggI := range aggI {
617			fexpr := sel.List[aggI].(spansql.Func)
618			if aggType[j] == nil {
619				// Fallback; there might not be any groups.
620				// TODO: Should this be in aggregateFunc?
621				aggType[j] = &int64Type
622			}
623			rawOut.cols = append(rawOut.cols, colInfo{
624				Name:     spansql.ID(fexpr.SQL()), // TODO: this is a bit hokey, but it is output only
625				Type:     *aggType[j],
626				AggIndex: aggI + 1,
627			})
628			sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value.
629				Type:     *aggType[j],
630				AggIndex: aggI + 1,
631			}
632		}
633		ri = rawOut
634		ec.cols = rawOut.cols
635	}
636
637	// TODO: Support table sampling.
638
639	// Apply SELECT list.
640	var colInfos []colInfo
641	for i, e := range sel.List {
642		if e == spansql.Star {
643			colInfos = append(colInfos, ec.cols...)
644		} else {
645			ci, err := ec.colInfo(e)
646			if err != nil {
647				return nil, err
648			}
649			if len(sel.ListAliases) > 0 {
650				alias := sel.ListAliases[i]
651				if alias != "" {
652					ci.Name = alias
653				}
654			}
655			// TODO: deal with ci.Name == ""?
656			colInfos = append(colInfos, ci)
657		}
658	}
659
660	return &selIter{
661		ri:   ri,
662		ec:   ec,
663		cis:  colInfos,
664		list: sel.List,
665
666		distinct: sel.Distinct, // Apply DISTINCT.
667	}, nil
668}
669
670func (d *database) evalSelectFrom(qc *queryContext, ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, error) {
671	switch sf := sf.(type) {
672	default:
673		return ec, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf)
674	case spansql.SelectFromTable:
675		t, ok := qc.tableIndex[sf.Table]
676		if !ok {
677			// This shouldn't be possible; the queryContext should have discovered missing tables already.
678			return ec, nil, fmt.Errorf("unknown table %q", sf.Table)
679		}
680		ti := &tableIter{t: t}
681		if sf.Alias != "" {
682			ti.alias = sf.Alias
683		} else {
684			// There is an implicit alias using the table name.
685			// https://cloud.google.com/spanner/docs/query-syntax#implicit_aliases
686			ti.alias = sf.Table
687		}
688		ec.cols = ti.Cols()
689		return ec, ti, nil
690	case spansql.SelectFromJoin:
691		// TODO: Avoid the toRawIter calls here by doing the RHS recursive evalSelectFrom in joinIter.Next on demand.
692
693		lhsEC, lhs, err := d.evalSelectFrom(qc, ec, sf.LHS)
694		if err != nil {
695			return ec, nil, err
696		}
697		lhsRaw, err := toRawIter(lhs)
698		if err != nil {
699			return ec, nil, err
700		}
701
702		rhsEC, rhs, err := d.evalSelectFrom(qc, ec, sf.RHS)
703		if err != nil {
704			return ec, nil, err
705		}
706		rhsRaw, err := toRawIter(rhs)
707		if err != nil {
708			return ec, nil, err
709		}
710
711		ji, ec, err := newJoinIter(lhsRaw, rhsRaw, lhsEC, rhsEC, sf)
712		if err != nil {
713			return ec, nil, err
714		}
715		return ec, ji, nil
716	case spansql.SelectFromUnnest:
717		// TODO: Do all relevant types flow through here? Path expressions might be tricky here.
718		col, err := ec.colInfo(sf.Expr)
719		if err != nil {
720			return ec, nil, fmt.Errorf("evaluating type of UNNEST arg: %v", err)
721		}
722		if !col.Type.Array {
723			return ec, nil, fmt.Errorf("type of UNNEST arg is non-array %s", col.Type.SQL())
724		}
725		// The output of this UNNEST is the non-array version.
726		col.Name = sf.Alias // may be empty
727		col.Type.Array = false
728
729		// Evaluate the expression, and yield a virtual table with one column.
730		e, err := ec.evalExpr(sf.Expr)
731		if err != nil {
732			return ec, nil, fmt.Errorf("evaluating UNNEST arg: %v", err)
733		}
734		arr, ok := e.([]interface{})
735		if !ok {
736			return ec, nil, fmt.Errorf("evaluating UNNEST arg gave %t, want array", e)
737		}
738		var rows []row
739		for _, v := range arr {
740			rows = append(rows, row{v})
741		}
742
743		ri := &rawIter{
744			cols: []colInfo{col},
745			rows: rows,
746		}
747		ec.cols = ri.cols
748		return ec, ri, nil
749	}
750}
751
752func newJoinIter(lhs, rhs *rawIter, lhsEC, rhsEC evalContext, sfj spansql.SelectFromJoin) (*joinIter, evalContext, error) {
753	if sfj.On != nil && len(sfj.Using) > 0 {
754		return nil, evalContext{}, fmt.Errorf("JOIN may not have both ON and USING clauses")
755	}
756	if sfj.On == nil && len(sfj.Using) == 0 && sfj.Type != spansql.CrossJoin {
757		// TODO: This isn't correct for joining against a non-table.
758		return nil, evalContext{}, fmt.Errorf("non-CROSS JOIN must have ON or USING clause")
759	}
760
761	// Start with the context from the LHS (aliases and params should be the same on both sides).
762	ji := &joinIter{
763		jt: sfj.Type,
764		ec: lhsEC,
765
766		primary:       lhs,
767		secondaryOrig: rhs,
768
769		primaryOffset:   0,
770		secondaryOffset: len(lhsEC.cols),
771	}
772	switch ji.jt {
773	case spansql.LeftJoin:
774		ji.nullPad = true
775	case spansql.RightJoin:
776		ji.nullPad = true
777		// Primary is RHS.
778		ji.ec = rhsEC
779		ji.primary, ji.secondaryOrig = rhs, lhs
780		ji.primaryOffset, ji.secondaryOffset = len(rhsEC.cols), 0
781	case spansql.FullJoin:
782		// FULL JOIN is implemented as a LEFT JOIN with tracking for which rows of the RHS
783		// have been used. Then, at the end of the iteration, the unused RHS rows are emitted.
784		ji.nullPad = true
785		ji.used = make([]bool, 0, 10) // arbitrary preallocation
786	}
787	ji.ec.cols, ji.ec.row = nil, nil
788
789	// Construct a merged evalContext, and prepare the join condition evaluation.
790	// TODO: Remove ambiguous names here? Or catch them when evaluated?
791	// TODO: aliases might need work?
792	if len(sfj.Using) == 0 {
793		ji.prepNonUsing(sfj.On, lhsEC, rhsEC)
794	} else {
795		if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC, ji.jt == spansql.RightJoin); err != nil {
796			return nil, evalContext{}, err
797		}
798	}
799
800	return ji, ji.ec, nil
801}
802
803// prepNonUsing configures the joinIter to evaluate with an ON clause or no join clause.
804// The arg is nil in the latter case.
805func (ji *joinIter) prepNonUsing(on spansql.BoolExpr, lhsEC, rhsEC evalContext) {
806	// Having ON or no clause results in the full set of columns from both sides.
807	// Force a copy.
808	ji.ec.cols = append(ji.ec.cols, lhsEC.cols...)
809	ji.ec.cols = append(ji.ec.cols, rhsEC.cols...)
810	ji.ec.row = make(row, len(ji.ec.cols))
811
812	ji.cond = func(primary, secondary row) (bool, error) {
813		copy(ji.ec.row[ji.primaryOffset:], primary)
814		copy(ji.ec.row[ji.secondaryOffset:], secondary)
815		if on == nil {
816			// No condition; all rows match.
817			return true, nil
818		}
819		b, err := ji.ec.evalBoolExpr(on)
820		if err != nil {
821			return false, err
822		}
823		return b != nil && *b, nil
824	}
825	ji.zero = func(primary, secondary row) {
826		for i := range ji.ec.row {
827			ji.ec.row[i] = nil
828		}
829		copy(ji.ec.row[ji.primaryOffset:], primary)
830		copy(ji.ec.row[ji.secondaryOffset:], secondary)
831	}
832}
833
// prepUsing configures the joinIter to evaluate a join with a USING clause.
//
// using lists the column names of the USING clause; each must resolve in both
// lhsEC and rhsEC, otherwise the resolution error is returned. flipped reports
// whether the primary side of the join is the RHS (the caller passes true for
// a RIGHT JOIN).
//
// On success it sets ji.ec.cols, allocates ji.ec.row, and installs ji.cond
// and ji.zero.
func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext, flipped bool) error {
	// Having a USING clause results in the set of named columns once,
	// followed by the unnamed columns from both sides.

	// lhsUsing is the column indexes in the LHS that the USING clause references.
	// rhsUsing is similar.
	// lhsNotUsing/rhsNotUsing are the complement.
	var lhsUsing, rhsUsing []int
	var lhsNotUsing, rhsNotUsing []int
	// lhsUsed, rhsUsed are the set of column indexes in lhsUsing/rhsUsing.
	lhsUsed, rhsUsed := make(map[int]bool), make(map[int]bool)
	for _, id := range using {
		lhsi, err := lhsEC.resolveColumnIndex(id)
		if err != nil {
			return err
		}
		lhsUsing = append(lhsUsing, lhsi)
		lhsUsed[lhsi] = true

		rhsi, err := rhsEC.resolveColumnIndex(id)
		if err != nil {
			return err
		}
		rhsUsing = append(rhsUsing, rhsi)
		rhsUsed[rhsi] = true

		// The output column for a USING name takes its metadata from the LHS.
		// TODO: Should this hide or merge column aliases?
		ji.ec.cols = append(ji.ec.cols, lhsEC.cols[lhsi])
	}
	// The remaining LHS columns come next, then the remaining RHS columns.
	for i, col := range lhsEC.cols {
		if !lhsUsed[i] {
			ji.ec.cols = append(ji.ec.cols, col)
			lhsNotUsing = append(lhsNotUsing, i)
		}
	}
	for i, col := range rhsEC.cols {
		if !rhsUsed[i] {
			ji.ec.cols = append(ji.ec.cols, col)
			rhsNotUsing = append(rhsNotUsing, i)
		}
	}
	ji.ec.row = make(row, len(ji.ec.cols))

	// cond, zero and populate receive rows as (primary, secondary),
	// so express the USING indexes in those terms too.
	primaryUsing, secondaryUsing := lhsUsing, rhsUsing
	if flipped {
		primaryUsing, secondaryUsing = secondaryUsing, primaryUsing
	}

	// orNil returns r[i], or nil (NULL) if the whole row is absent.
	orNil := func(r row, i int) interface{} {
		if r == nil {
			return nil
		}
		return r[i]
	}
	// populate writes the data to ji.ec.row in the correct positions.
	populate := func(primary, secondary row) { // either may be nil
		j := 0
		// The USING columns are taken from the primary row when it is present,
		// and from the secondary row otherwise.
		if primary != nil {
			for _, pi := range primaryUsing {
				ji.ec.row[j] = primary[pi]
				j++
			}
		} else {
			for _, si := range secondaryUsing {
				ji.ec.row[j] = secondary[si]
				j++
			}
		}
		// Translate primary/secondary back to LHS/RHS for the remaining columns,
		// whose output order is always LHS then RHS.
		lhs, rhs := primary, secondary
		if flipped {
			rhs, lhs = lhs, rhs
		}
		for _, i := range lhsNotUsing {
			ji.ec.row[j] = orNil(lhs, i)
			j++
		}
		for _, i := range rhsNotUsing {
			ji.ec.row[j] = orNil(rhs, i)
			j++
		}
	}
	// Rows join when every USING column compares equal (per compareVals);
	// the output row is only populated on a match.
	ji.cond = func(primary, secondary row) (bool, error) {
		for i, pi := range primaryUsing {
			si := secondaryUsing[i]
			if compareVals(primary[pi], secondary[si]) != 0 {
				return false, nil
			}
		}
		populate(primary, secondary)
		return true, nil
	}
	// populate already writes NULL for every column of an absent row.
	ji.zero = func(primary, secondary row) {
		populate(primary, secondary)
	}
	return nil
}
930
// joinIter is a rowIter that yields the rows of a JOIN.
// It is constructed by newJoinIter, which also installs cond and zero.
// Rows returned by Next are the shared buffer ec.row.
type joinIter struct {
	jt spansql.JoinType
	ec evalContext // combined context

	// The "primary" is scanned (consumed), but the secondary is cloned for each primary row.
	// Most join types have primary==LHS; a RIGHT JOIN is the exception.
	primary, secondaryOrig *rawIter

	// The offsets into ec.row that the primary/secondary rows should appear
	// in the final output. Not used when there's a USING clause.
	primaryOffset, secondaryOffset int
	// nullPad is whether primary rows without matching secondary rows
	// should be yielded with null padding (e.g. OUTER JOINs).
	nullPad bool

	primaryRow    row      // current row from primary, or nil if it is time to advance
	secondary     *rawIter // current clone of secondary
	secondaryRead int      // number of rows already read from secondary
	any           bool     // true if any secondary rows have matched primaryRow

	// cond reports whether the primary and secondary rows "join" (e.g. the ON clause is true).
	// It populates ec.row with the output.
	cond func(primary, secondary row) (bool, error)
	// zero populates ec.row with the primary or secondary row data (either of which may be nil),
	// and sets the remainder to NULL.
	// This is used when nullPad is true and a primary or secondary row doesn't match.
	zero func(primary, secondary row)

	// For FULL JOIN, this tracks the secondary rows that have been used.
	// It is non-nil when being used.
	used        []bool
	zeroUnused  bool // set when emitting unused secondary rows
	unusedIndex int  // next index of used to check
}

// Cols returns the metadata of the joined rows, taken from the combined context.
func (ji *joinIter) Cols() []colInfo { return ji.ec.cols }
967
968func (ji *joinIter) nextPrimary() error {
969	var err error
970	ji.primaryRow, err = ji.primary.Next()
971	if err != nil {
972		return err
973	}
974	ji.secondary = ji.secondaryOrig.clone()
975	ji.secondaryRead = 0
976	ji.any = false
977	return nil
978}
979
980func (ji *joinIter) Next() (row, error) {
981	if ji.primaryRow == nil && !ji.zeroUnused {
982		err := ji.nextPrimary()
983		if err == io.EOF && ji.used != nil {
984			// Drop down to emitting unused secondary rows.
985			ji.zeroUnused = true
986			ji.secondary = nil
987			goto scanJiUsed
988		}
989		if err != nil {
990			return nil, err
991		}
992	}
993scanJiUsed:
994	if ji.zeroUnused {
995		if ji.secondary == nil {
996			ji.secondary = ji.secondaryOrig.clone()
997			ji.secondaryRead = 0
998		}
999		for ji.unusedIndex < len(ji.used) && ji.used[ji.unusedIndex] {
1000			ji.unusedIndex++
1001		}
1002		if ji.unusedIndex >= len(ji.used) || ji.secondaryRead >= len(ji.used) {
1003			// Truly finished.
1004			return nil, io.EOF
1005		}
1006		var secondaryRow row
1007		for ji.secondaryRead <= ji.unusedIndex {
1008			var err error
1009			secondaryRow, err = ji.secondary.Next()
1010			if err != nil {
1011				return nil, err
1012			}
1013			ji.secondaryRead++
1014		}
1015		ji.zero(nil, secondaryRow)
1016		return ji.ec.row, nil
1017	}
1018
1019	for {
1020		secondaryRow, err := ji.secondary.Next()
1021		if err == io.EOF {
1022			// Finished the current primary row.
1023
1024			if !ji.any && ji.nullPad {
1025				ji.zero(ji.primaryRow, nil)
1026				ji.primaryRow = nil
1027				return ji.ec.row, nil
1028			}
1029
1030			// Advance to next one.
1031			err := ji.nextPrimary()
1032			if err == io.EOF && ji.used != nil {
1033				ji.zeroUnused = true
1034				ji.secondary = nil
1035				goto scanJiUsed
1036			}
1037			if err != nil {
1038				return nil, err
1039			}
1040			continue
1041		}
1042		if err != nil {
1043			return nil, err
1044		}
1045		ji.secondaryRead++
1046		if ji.used != nil {
1047			for len(ji.used) < ji.secondaryRead {
1048				ji.used = append(ji.used, false)
1049			}
1050		}
1051
1052		// We have a pair of rows to consider.
1053		match, err := ji.cond(ji.primaryRow, secondaryRow)
1054		if err != nil {
1055			return nil, err
1056		}
1057		if !match {
1058			continue
1059		}
1060		ji.any = true
1061		if ji.used != nil {
1062			// Make a note that we used this secondary row.
1063			ji.used[ji.secondaryRead-1] = true
1064		}
1065		return ji.ec.row, nil
1066	}
1067}
1068
1069func evalSelectOrder(si *selIter, aux []spansql.Expr) (rows []row, keys [][]interface{}, err error) {
1070	// This is like toRawIter except it also evaluates the auxiliary expressions for ORDER BY.
1071	for {
1072		r, err := si.Next()
1073		if err == io.EOF {
1074			break
1075		} else if err != nil {
1076			return nil, nil, err
1077		}
1078		key, err := si.ec.evalExprList(aux)
1079		if err != nil {
1080			return nil, nil, err
1081		}
1082
1083		rows = append(rows, r.copyAllData())
1084		keys = append(keys, key)
1085	}
1086	return
1087}
1088
1089// externalRowSorter implements sort.Interface for a slice of rows
1090// with an external sort key.
1091type externalRowSorter struct {
1092	rows []row
1093	keys [][]interface{}
1094	desc []bool // may be nil
1095}
1096
1097func (ers externalRowSorter) Len() int { return len(ers.rows) }
1098func (ers externalRowSorter) Less(i, j int) bool {
1099	return compareValLists(ers.keys[i], ers.keys[j], ers.desc) < 0
1100}
1101func (ers externalRowSorter) Swap(i, j int) {
1102	ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i]
1103	ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i]
1104}
1105