1/*
2Copyright 2020 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spannertest
18
19import (
20	"fmt"
21	"io"
22	"sort"
23
24	"cloud.google.com/go/spanner/spansql"
25)
26
27/*
There are several ways to conceptualise SQL queries. The simplest, and what
29we implement here, is a series of pipelines that transform the data, whether
30pulling from a table (FROM tbl), filtering (WHERE expr), re-ordering (ORDER BY expr)
31or other transformations.
32
33The order of operations among those supported by Cloud Spanner is
34	FROM + JOIN + set ops [TODO: JOIN and set ops]
35	WHERE
36	GROUP BY
37	aggregation
38	HAVING [TODO]
39	SELECT
40	DISTINCT
41	ORDER BY
42	OFFSET [TODO]
43	LIMIT
44*/
45
// rowIter represents some iteration over rows of data.
// It is returned by reads and queries.
//
// Rows are handed out one at a time; a caller that needs the whole result
// at once can drain an iterator with toRawIter.
type rowIter interface {
	// Cols returns the metadata about the returned data.
	Cols() []colInfo

	// Next returns the next row.
	// If done, it returns (nil, io.EOF).
	// The implementations in this file return a nil row alongside any
	// other non-nil error as well.
	Next() (row, error)
}
56
// aggSentinel is a synthetic expression that refers to an aggregated value.
// It is transient only; it is never stored and only used during evaluation.
// evalSelect puts one in the SELECT list in place of the aggregate function
// call once the aggregate values have been computed.
type aggSentinel struct {
	// Expr is embedded (evalSelect leaves it nil) so that an aggSentinel
	// can be placed in a list of spansql.Expr values.
	spansql.Expr
	Type     spansql.Type // Type of the aggregated value.
	AggIndex int          // Index+1 of SELECT list.
}
64
65// nullIter is a rowIter that returns one empty row only.
66// This is used for queries without a table.
67type nullIter struct {
68	done bool
69}
70
71func (ni *nullIter) Cols() []colInfo { return nil }
72func (ni *nullIter) Next() (row, error) {
73	if ni.done {
74		return nil, io.EOF
75	}
76	ni.done = true
77	return nil, nil
78}
79
80// tableIter is a rowIter that walks a table.
81// It assumes the table is locked for the duration.
82type tableIter struct {
83	t        *table
84	rowIndex int // index of next row to return
85}
86
87func (ti *tableIter) Cols() []colInfo { return ti.t.cols }
88func (ti *tableIter) Next() (row, error) {
89	if ti.rowIndex >= len(ti.t.rows) {
90		return nil, io.EOF
91	}
92	res := ti.t.rows[ti.rowIndex]
93	ti.rowIndex++
94	return res, nil
95}
96
97// rawIter is a rowIter with fixed data.
98type rawIter struct {
99	// cols is the metadata about the returned data.
100	cols []colInfo
101
102	// rows holds the result data itself.
103	rows []row
104}
105
106func (raw *rawIter) Cols() []colInfo { return raw.cols }
107func (raw *rawIter) Next() (row, error) {
108	if len(raw.rows) == 0 {
109		return nil, io.EOF
110	}
111	res := raw.rows[0]
112	raw.rows = raw.rows[1:]
113	return res, nil
114}
115
116func (raw *rawIter) add(src row, colIndexes []int) {
117	raw.rows = append(raw.rows, src.copyData(colIndexes))
118}
119
120func toRawIter(ri rowIter) (*rawIter, error) {
121	if raw, ok := ri.(*rawIter); ok {
122		return raw, nil
123	}
124	raw := &rawIter{cols: ri.Cols()}
125	for {
126		row, err := ri.Next()
127		if err == io.EOF {
128			break
129		} else if err != nil {
130			return nil, err
131		}
132		raw.rows = append(raw.rows, row)
133	}
134	return raw, nil
135}
136
137// whereIter applies a WHERE clause.
138type whereIter struct {
139	ri    rowIter
140	ec    evalContext
141	where spansql.BoolExpr
142}
143
144func (wi whereIter) Cols() []colInfo { return wi.ri.Cols() }
145func (wi whereIter) Next() (row, error) {
146	for {
147		row, err := wi.ri.Next()
148		if err != nil {
149			return nil, err
150		}
151		wi.ec.row = row
152
153		b, err := wi.ec.evalBoolExpr(wi.where)
154		if err != nil {
155			return nil, err
156		}
157		if !b {
158			continue
159		}
160		return row, nil
161	}
162}
163
164// selIter applies a SELECT list.
165type selIter struct {
166	ri   rowIter
167	ec   evalContext
168	cis  []colInfo
169	list []spansql.Expr
170}
171
172func (si selIter) Cols() []colInfo { return si.cis }
173func (si selIter) Next() (row, error) {
174	row, err := si.ri.Next()
175	if err != nil {
176		return nil, err
177	}
178	si.ec.row = row
179
180	selectStar := len(si.list) == 1 && si.list[0] == spansql.Star
181	if selectStar {
182		return row, nil
183	}
184
185	return si.ec.evalExprList(si.list)
186}
187
188// distinctIter applies a DISTINCT filter.
189type distinctIter struct {
190	ri   rowIter
191	seen []row
192}
193
194func (di *distinctIter) Cols() []colInfo { return di.ri.Cols() }
195func (di *distinctIter) Next() (row, error) {
196	// This is hilariously inefficient; O(N^2) in the number of returned rows.
197	// Some sort of hashing could be done to deduplicate instead.
198	// This also breaks on array/struct types.
199	for {
200		row, err := di.ri.Next()
201		if err != nil {
202			return nil, err
203		}
204		dupe := false
205		for _, prev := range di.seen {
206			if rowEqual(prev, row) {
207				dupe = true
208				break
209			}
210		}
211		if dupe {
212			continue
213		}
214		di.seen = append(di.seen, row)
215		return row, nil
216	}
217}
218
219// limitIter applies a LIMIT clause.
220type limitIter struct {
221	ri  rowIter
222	rem int64
223}
224
225func (li *limitIter) Cols() []colInfo { return li.ri.Cols() }
226func (li *limitIter) Next() (row, error) {
227	if li.rem <= 0 {
228		return nil, io.EOF
229	}
230	row, err := li.ri.Next()
231	if err != nil {
232		return nil, err
233	}
234	li.rem--
235	return row, nil
236}
237
// queryParams holds the values supplied alongside a query, keyed by string.
// NOTE(review): presumably the key is the query parameter's name; confirm
// against the code that looks parameters up.
type queryParams map[string]interface{}
239
240func (d *database) Query(q spansql.Query, params queryParams) (rowIter, error) {
241	// If there's an ORDER BY clause, extend the query to include the expressions we need
242	// so they get evaluated during evalSelect. TODO: Is this actually okay?
243	var aux []spansql.Expr
244	var desc []bool
245	for _, o := range q.Order {
246		aux = append(aux, o.Expr)
247		desc = append(desc, o.Desc)
248	}
249	q.Select.List = append(q.Select.List, aux...)
250
251	ri, err := d.evalSelect(q.Select, params)
252	if err != nil {
253		return nil, err
254	}
255
256	// Apply ORDER BY.
257	if len(q.Order) > 0 {
258		raw, err := toRawIter(ri)
259		if err != nil {
260			return nil, err
261		}
262		sort.Slice(raw.rows, func(one, two int) bool {
263			r1, r2 := raw.rows[one], raw.rows[two]
264			aux1, aux2 := r1[len(r1)-len(aux):], r2[len(r2)-len(aux):] // sort keys
265			return compareValLists(aux1, aux2, desc) < 0
266		})
267		// Remove ORDER BY values.
268		raw.cols = raw.cols[:len(raw.cols)-len(aux)]
269		for i, row := range raw.rows {
270			raw.rows[i] = row[:len(row)-len(aux)]
271		}
272		ri = raw
273	}
274
275	// TODO: OFFSET
276
277	// Apply LIMIT.
278	if q.Limit != nil {
279		lim, err := evalLimit(q.Limit, params)
280		if err != nil {
281			return nil, err
282		}
283		ri = &limitIter{ri: ri, rem: lim}
284	}
285
286	return ri, nil
287}
288
289func (d *database) evalSelect(sel spansql.Select, params queryParams) (ri rowIter, evalErr error) {
290	ri = &nullIter{}
291	ec := evalContext{
292		params: params,
293	}
294
295	// First stage is to identify the data source.
296	// If there's a FROM then that names a table to use.
297	if len(sel.From) > 1 {
298		return nil, fmt.Errorf("selecting from more than one table not yet supported")
299	}
300	if len(sel.From) == 1 {
301		tableName := sel.From[0].Table
302		t, err := d.table(tableName)
303		if err != nil {
304			return nil, err
305		}
306		t.mu.Lock()
307		defer t.mu.Unlock()
308		ri = &tableIter{t: t}
309		ec.cols = t.cols
310	}
311	defer func() {
312		// If we're about to return a tableIter, convert it to a rawIter
313		// so that the table may be safely unlocked.
314		if evalErr == nil {
315			if ti, ok := ri.(*tableIter); ok {
316				ri, evalErr = toRawIter(ti)
317			}
318		}
319	}()
320
321	// Apply WHERE.
322	if sel.Where != nil {
323		ri = whereIter{
324			ri:    ri,
325			ec:    ec,
326			where: sel.Where,
327		}
328	}
329
330	// Apply GROUP BY.
331	// This only reorders rows to group rows together;
332	// aggregation happens next.
333	var rowGroups [][2]int // Sequence of half-open intervals of row numbers.
334	if len(sel.GroupBy) > 0 {
335		raw, err := toRawIter(ri)
336		if err != nil {
337			return nil, err
338		}
339		keys := make([][]interface{}, 0, len(raw.rows))
340		for _, row := range raw.rows {
341			// Evaluate sort key for this row.
342			// TODO: Support referring to expression names in the SELECT list;
343			// this may require passing through sel.List, or maybe mutating
344			// sel.GroupBy to copy the referenced values. This will also be
345			// required to support grouping by aliases.
346			ec.row = row
347			key, err := ec.evalExprList(sel.GroupBy)
348			if err != nil {
349				return nil, err
350			}
351			keys = append(keys, key)
352		}
353
354		// Reorder rows base on the evaluated keys.
355		ers := externalRowSorter{rows: raw.rows, keys: keys}
356		sort.Sort(ers)
357		raw.rows = ers.rows
358
359		// Record groups as a sequence of row intervals.
360		// Each group is a run of the same keys.
361		start := 0
362		for i := 1; i < len(keys); i++ {
363			if compareValLists(keys[i-1], keys[i], nil) == 0 {
364				continue
365			}
366			rowGroups = append(rowGroups, [2]int{start, i})
367			start = i
368		}
369		if len(keys) > 0 {
370			rowGroups = append(rowGroups, [2]int{start, len(keys)})
371		}
372
373		ri = raw
374	}
375
376	// Handle aggregation.
377	// TODO: Support more than one aggregation function; does Spanner support that?
378	aggI := -1
379	for i, e := range sel.List {
380		// Supported aggregate funcs have exactly one arg.
381		f, ok := e.(spansql.Func)
382		if !ok || len(f.Args) != 1 {
383			continue
384		}
385		_, ok = aggregateFuncs[f.Name]
386		if !ok {
387			continue
388		}
389		if aggI > -1 {
390			return nil, fmt.Errorf("only one aggregate function is supported")
391		}
392		aggI = i
393	}
394	if aggI > -1 {
395		raw, err := toRawIter(ri)
396		if err != nil {
397			return nil, err
398		}
399		if len(rowGroups) == 0 {
400			// No grouping, so aggregation applies to the entire table (e.g. COUNT(*)).
401			rowGroups = [][2]int{{0, len(raw.rows)}}
402		}
403		fexpr := sel.List[aggI].(spansql.Func)
404		fn := aggregateFuncs[fexpr.Name]
405		starArg := fexpr.Args[0] == spansql.Star
406		if starArg && !fn.AcceptStar {
407			return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name)
408		}
409
410		// Prepare output.
411		rawOut := &rawIter{
412			// Same as input columns, but also the aggregate value.
413			// Add the colInfo for the aggregate at the end
414			// so we know the type.
415			// Make a copy for safety.
416			cols: append([]colInfo(nil), raw.cols...),
417		}
418
419		var aggType spansql.Type
420		for _, rg := range rowGroups {
421			// Compute aggregate value across this group.
422			var values []interface{}
423			for i := rg[0]; i < rg[1]; i++ {
424				ec.row = raw.rows[i]
425				if starArg {
426					// A non-NULL placeholder is sufficient for aggregation.
427					values = append(values, 1)
428				} else {
429					x, err := ec.evalExpr(fexpr.Args[0])
430					if err != nil {
431						return nil, err
432					}
433					values = append(values, x)
434				}
435			}
436			x, typ, err := fn.Eval(values)
437			if err != nil {
438				return nil, err
439			}
440			aggType = typ
441			// Output for the row group is the first row of the group (arbitrary,
442			// but it should be representative), and the aggregate value.
443			// TODO: Should this exclude the aggregated expressions so they can't be selected?
444			repRow := raw.rows[rg[0]]
445			var outRow row
446			for i := range repRow {
447				outRow = append(outRow, repRow.copyDataElem(i))
448			}
449			outRow = append(outRow, x)
450			rawOut.rows = append(rawOut.rows, outRow)
451		}
452
453		if aggType == (spansql.Type{}) {
454			// Fallback; there might not be any groups.
455			// TODO: Should this be in aggregateFunc?
456			aggType = int64Type
457		}
458		rawOut.cols = append(raw.cols, colInfo{
459			Name:     fexpr.SQL(),
460			Type:     aggType,
461			AggIndex: aggI + 1,
462		})
463
464		ri = rawOut
465		ec.cols = rawOut.cols
466		sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value.
467			Type:     aggType,
468			AggIndex: aggI + 1,
469		}
470	}
471
472	// TODO: Support table sampling.
473
474	// Apply SELECT list.
475	var colInfos []colInfo
476	// Is this a `SELECT *` query?
477	selectStar := len(sel.List) == 1 && sel.List[0] == spansql.Star
478	if selectStar {
479		// Every column will appear in the output.
480		colInfos = ec.cols
481	} else {
482		for _, e := range sel.List {
483			ci, err := ec.colInfo(e)
484			if err != nil {
485				return nil, err
486			}
487			// TODO: deal with ci.Name == ""?
488			colInfos = append(colInfos, ci)
489		}
490	}
491	ri = selIter{
492		ri:   ri,
493		ec:   ec,
494		cis:  colInfos,
495		list: sel.List,
496	}
497
498	// Apply DISTINCT.
499	if sel.Distinct {
500		ri = &distinctIter{ri: ri}
501	}
502
503	return ri, nil
504}
505
506// externalRowSorter implements sort.Interface for a slice of rows
507// with an external sort key.
508type externalRowSorter struct {
509	rows []row
510	keys [][]interface{}
511}
512
513func (ers externalRowSorter) Len() int { return len(ers.rows) }
514func (ers externalRowSorter) Less(i, j int) bool {
515	return compareValLists(ers.keys[i], ers.keys[j], nil) < 0
516}
517func (ers externalRowSorter) Swap(i, j int) {
518	ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i]
519	ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i]
520}
521