1/* 2Copyright 2020 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package spannertest 18 19import ( 20 "fmt" 21 "io" 22 "sort" 23 24 "cloud.google.com/go/spanner/spansql" 25) 26 27/* 28There's several ways to conceptualise SQL queries. The simplest, and what 29we implement here, is a series of pipelines that transform the data, whether 30pulling from a table (FROM tbl), filtering (WHERE expr), re-ordering (ORDER BY expr) 31or other transformations. 32 33The order of operations among those supported by Cloud Spanner is 34 FROM + JOIN + set ops [TODO: JOIN and set ops] 35 WHERE 36 GROUP BY 37 aggregation 38 HAVING [TODO] 39 SELECT 40 DISTINCT 41 ORDER BY 42 OFFSET [TODO] 43 LIMIT 44*/ 45 46// rowIter represents some iteration over rows of data. 47// It is returned by reads and queries. 48type rowIter interface { 49 // Cols returns the metadata about the returned data. 50 Cols() []colInfo 51 52 // Next returns the next row. 53 // If done, it returns (nil, io.EOF). 54 Next() (row, error) 55} 56 57// aggSentinel is a synthetic expression that refers to an aggregated value. 58// It is transient only; it is never stored and only used during evaluation. 59type aggSentinel struct { 60 spansql.Expr 61 Type spansql.Type 62 AggIndex int // Index+1 of SELECT list. 63} 64 65// nullIter is a rowIter that returns one empty row only. 66// This is used for queries without a table. 67type nullIter struct { 68 done bool 69} 70 71func (ni *nullIter) Cols() []colInfo { return nil } 72func (ni *nullIter) Next() (row, error) { 73 if ni.done { 74 return nil, io.EOF 75 } 76 ni.done = true 77 return nil, nil 78} 79 80// tableIter is a rowIter that walks a table. 81// It assumes the table is locked for the duration. 82type tableIter struct { 83 t *table 84 rowIndex int // index of next row to return 85} 86 87func (ti *tableIter) Cols() []colInfo { return ti.t.cols } 88func (ti *tableIter) Next() (row, error) { 89 if ti.rowIndex >= len(ti.t.rows) { 90 return nil, io.EOF 91 } 92 res := ti.t.rows[ti.rowIndex] 93 ti.rowIndex++ 94 return res, nil 95} 96 97// rawIter is a rowIter with fixed data. 98type rawIter struct { 99 // cols is the metadata about the returned data. 100 cols []colInfo 101 102 // rows holds the result data itself. 103 rows []row 104} 105 106func (raw *rawIter) Cols() []colInfo { return raw.cols } 107func (raw *rawIter) Next() (row, error) { 108 if len(raw.rows) == 0 { 109 return nil, io.EOF 110 } 111 res := raw.rows[0] 112 raw.rows = raw.rows[1:] 113 return res, nil 114} 115 116func (raw *rawIter) add(src row, colIndexes []int) { 117 raw.rows = append(raw.rows, src.copyData(colIndexes)) 118} 119 120func toRawIter(ri rowIter) (*rawIter, error) { 121 if raw, ok := ri.(*rawIter); ok { 122 return raw, nil 123 } 124 raw := &rawIter{cols: ri.Cols()} 125 for { 126 row, err := ri.Next() 127 if err == io.EOF { 128 break 129 } else if err != nil { 130 return nil, err 131 } 132 raw.rows = append(raw.rows, row) 133 } 134 return raw, nil 135} 136 137// whereIter applies a WHERE clause. 138type whereIter struct { 139 ri rowIter 140 ec evalContext 141 where spansql.BoolExpr 142} 143 144func (wi whereIter) Cols() []colInfo { return wi.ri.Cols() } 145func (wi whereIter) Next() (row, error) { 146 for { 147 row, err := wi.ri.Next() 148 if err != nil { 149 return nil, err 150 } 151 wi.ec.row = row 152 153 b, err := wi.ec.evalBoolExpr(wi.where) 154 if err != nil { 155 return nil, err 156 } 157 if !b { 158 continue 159 } 160 return row, nil 161 } 162} 163 164// selIter applies a SELECT list. 165type selIter struct { 166 ri rowIter 167 ec evalContext 168 cis []colInfo 169 list []spansql.Expr 170} 171 172func (si selIter) Cols() []colInfo { return si.cis } 173func (si selIter) Next() (row, error) { 174 row, err := si.ri.Next() 175 if err != nil { 176 return nil, err 177 } 178 si.ec.row = row 179 180 selectStar := len(si.list) == 1 && si.list[0] == spansql.Star 181 if selectStar { 182 return row, nil 183 } 184 185 return si.ec.evalExprList(si.list) 186} 187 188// distinctIter applies a DISTINCT filter. 189type distinctIter struct { 190 ri rowIter 191 seen []row 192} 193 194func (di *distinctIter) Cols() []colInfo { return di.ri.Cols() } 195func (di *distinctIter) Next() (row, error) { 196 // This is hilariously inefficient; O(N^2) in the number of returned rows. 197 // Some sort of hashing could be done to deduplicate instead. 198 // This also breaks on array/struct types. 199 for { 200 row, err := di.ri.Next() 201 if err != nil { 202 return nil, err 203 } 204 dupe := false 205 for _, prev := range di.seen { 206 if rowEqual(prev, row) { 207 dupe = true 208 break 209 } 210 } 211 if dupe { 212 continue 213 } 214 di.seen = append(di.seen, row) 215 return row, nil 216 } 217} 218 219// limitIter applies a LIMIT clause. 220type limitIter struct { 221 ri rowIter 222 rem int64 223} 224 225func (li *limitIter) Cols() []colInfo { return li.ri.Cols() } 226func (li *limitIter) Next() (row, error) { 227 if li.rem <= 0 { 228 return nil, io.EOF 229 } 230 row, err := li.ri.Next() 231 if err != nil { 232 return nil, err 233 } 234 li.rem-- 235 return row, nil 236} 237 238type queryParams map[string]interface{} 239 240func (d *database) Query(q spansql.Query, params queryParams) (rowIter, error) { 241 // If there's an ORDER BY clause, extend the query to include the expressions we need 242 // so they get evaluated during evalSelect. TODO: Is this actually okay? 243 var aux []spansql.Expr 244 var desc []bool 245 for _, o := range q.Order { 246 aux = append(aux, o.Expr) 247 desc = append(desc, o.Desc) 248 } 249 q.Select.List = append(q.Select.List, aux...) 250 251 ri, err := d.evalSelect(q.Select, params) 252 if err != nil { 253 return nil, err 254 } 255 256 // Apply ORDER BY. 257 if len(q.Order) > 0 { 258 raw, err := toRawIter(ri) 259 if err != nil { 260 return nil, err 261 } 262 sort.Slice(raw.rows, func(one, two int) bool { 263 r1, r2 := raw.rows[one], raw.rows[two] 264 aux1, aux2 := r1[len(r1)-len(aux):], r2[len(r2)-len(aux):] // sort keys 265 return compareValLists(aux1, aux2, desc) < 0 266 }) 267 // Remove ORDER BY values. 268 raw.cols = raw.cols[:len(raw.cols)-len(aux)] 269 for i, row := range raw.rows { 270 raw.rows[i] = row[:len(row)-len(aux)] 271 } 272 ri = raw 273 } 274 275 // TODO: OFFSET 276 277 // Apply LIMIT. 278 if q.Limit != nil { 279 lim, err := evalLimit(q.Limit, params) 280 if err != nil { 281 return nil, err 282 } 283 ri = &limitIter{ri: ri, rem: lim} 284 } 285 286 return ri, nil 287} 288 289func (d *database) evalSelect(sel spansql.Select, params queryParams) (ri rowIter, evalErr error) { 290 ri = &nullIter{} 291 ec := evalContext{ 292 params: params, 293 } 294 295 // First stage is to identify the data source. 296 // If there's a FROM then that names a table to use. 297 if len(sel.From) > 1 { 298 return nil, fmt.Errorf("selecting from more than one table not yet supported") 299 } 300 if len(sel.From) == 1 { 301 tableName := sel.From[0].Table 302 t, err := d.table(tableName) 303 if err != nil { 304 return nil, err 305 } 306 t.mu.Lock() 307 defer t.mu.Unlock() 308 ri = &tableIter{t: t} 309 ec.cols = t.cols 310 } 311 defer func() { 312 // If we're about to return a tableIter, convert it to a rawIter 313 // so that the table may be safely unlocked. 314 if evalErr == nil { 315 if ti, ok := ri.(*tableIter); ok { 316 ri, evalErr = toRawIter(ti) 317 } 318 } 319 }() 320 321 // Apply WHERE. 322 if sel.Where != nil { 323 ri = whereIter{ 324 ri: ri, 325 ec: ec, 326 where: sel.Where, 327 } 328 } 329 330 // Apply GROUP BY. 331 // This only reorders rows to group rows together; 332 // aggregation happens next. 333 var rowGroups [][2]int // Sequence of half-open intervals of row numbers. 334 if len(sel.GroupBy) > 0 { 335 raw, err := toRawIter(ri) 336 if err != nil { 337 return nil, err 338 } 339 keys := make([][]interface{}, 0, len(raw.rows)) 340 for _, row := range raw.rows { 341 // Evaluate sort key for this row. 342 // TODO: Support referring to expression names in the SELECT list; 343 // this may require passing through sel.List, or maybe mutating 344 // sel.GroupBy to copy the referenced values. This will also be 345 // required to support grouping by aliases. 346 ec.row = row 347 key, err := ec.evalExprList(sel.GroupBy) 348 if err != nil { 349 return nil, err 350 } 351 keys = append(keys, key) 352 } 353 354 // Reorder rows base on the evaluated keys. 355 ers := externalRowSorter{rows: raw.rows, keys: keys} 356 sort.Sort(ers) 357 raw.rows = ers.rows 358 359 // Record groups as a sequence of row intervals. 360 // Each group is a run of the same keys. 361 start := 0 362 for i := 1; i < len(keys); i++ { 363 if compareValLists(keys[i-1], keys[i], nil) == 0 { 364 continue 365 } 366 rowGroups = append(rowGroups, [2]int{start, i}) 367 start = i 368 } 369 if len(keys) > 0 { 370 rowGroups = append(rowGroups, [2]int{start, len(keys)}) 371 } 372 373 ri = raw 374 } 375 376 // Handle aggregation. 377 // TODO: Support more than one aggregation function; does Spanner support that? 378 aggI := -1 379 for i, e := range sel.List { 380 // Supported aggregate funcs have exactly one arg. 381 f, ok := e.(spansql.Func) 382 if !ok || len(f.Args) != 1 { 383 continue 384 } 385 _, ok = aggregateFuncs[f.Name] 386 if !ok { 387 continue 388 } 389 if aggI > -1 { 390 return nil, fmt.Errorf("only one aggregate function is supported") 391 } 392 aggI = i 393 } 394 if aggI > -1 { 395 raw, err := toRawIter(ri) 396 if err != nil { 397 return nil, err 398 } 399 if len(rowGroups) == 0 { 400 // No grouping, so aggregation applies to the entire table (e.g. COUNT(*)). 401 rowGroups = [][2]int{{0, len(raw.rows)}} 402 } 403 fexpr := sel.List[aggI].(spansql.Func) 404 fn := aggregateFuncs[fexpr.Name] 405 starArg := fexpr.Args[0] == spansql.Star 406 if starArg && !fn.AcceptStar { 407 return nil, fmt.Errorf("aggregate function %s does not accept * as an argument", fexpr.Name) 408 } 409 410 // Prepare output. 411 rawOut := &rawIter{ 412 // Same as input columns, but also the aggregate value. 413 // Add the colInfo for the aggregate at the end 414 // so we know the type. 415 // Make a copy for safety. 416 cols: append([]colInfo(nil), raw.cols...), 417 } 418 419 var aggType spansql.Type 420 for _, rg := range rowGroups { 421 // Compute aggregate value across this group. 422 var values []interface{} 423 for i := rg[0]; i < rg[1]; i++ { 424 ec.row = raw.rows[i] 425 if starArg { 426 // A non-NULL placeholder is sufficient for aggregation. 427 values = append(values, 1) 428 } else { 429 x, err := ec.evalExpr(fexpr.Args[0]) 430 if err != nil { 431 return nil, err 432 } 433 values = append(values, x) 434 } 435 } 436 x, typ, err := fn.Eval(values) 437 if err != nil { 438 return nil, err 439 } 440 aggType = typ 441 // Output for the row group is the first row of the group (arbitrary, 442 // but it should be representative), and the aggregate value. 443 // TODO: Should this exclude the aggregated expressions so they can't be selected? 444 repRow := raw.rows[rg[0]] 445 var outRow row 446 for i := range repRow { 447 outRow = append(outRow, repRow.copyDataElem(i)) 448 } 449 outRow = append(outRow, x) 450 rawOut.rows = append(rawOut.rows, outRow) 451 } 452 453 if aggType == (spansql.Type{}) { 454 // Fallback; there might not be any groups. 455 // TODO: Should this be in aggregateFunc? 456 aggType = int64Type 457 } 458 rawOut.cols = append(raw.cols, colInfo{ 459 Name: fexpr.SQL(), 460 Type: aggType, 461 AggIndex: aggI + 1, 462 }) 463 464 ri = rawOut 465 ec.cols = rawOut.cols 466 sel.List[aggI] = aggSentinel{ // Mutate query so evalExpr in selIter picks out the new value. 467 Type: aggType, 468 AggIndex: aggI + 1, 469 } 470 } 471 472 // TODO: Support table sampling. 473 474 // Apply SELECT list. 475 var colInfos []colInfo 476 // Is this a `SELECT *` query? 477 selectStar := len(sel.List) == 1 && sel.List[0] == spansql.Star 478 if selectStar { 479 // Every column will appear in the output. 480 colInfos = ec.cols 481 } else { 482 for _, e := range sel.List { 483 ci, err := ec.colInfo(e) 484 if err != nil { 485 return nil, err 486 } 487 // TODO: deal with ci.Name == ""? 488 colInfos = append(colInfos, ci) 489 } 490 } 491 ri = selIter{ 492 ri: ri, 493 ec: ec, 494 cis: colInfos, 495 list: sel.List, 496 } 497 498 // Apply DISTINCT. 499 if sel.Distinct { 500 ri = &distinctIter{ri: ri} 501 } 502 503 return ri, nil 504} 505 506// externalRowSorter implements sort.Interface for a slice of rows 507// with an external sort key. 508type externalRowSorter struct { 509 rows []row 510 keys [][]interface{} 511} 512 513func (ers externalRowSorter) Len() int { return len(ers.rows) } 514func (ers externalRowSorter) Less(i, j int) bool { 515 return compareValLists(ers.keys[i], ers.keys[j], nil) < 0 516} 517func (ers externalRowSorter) Swap(i, j int) { 518 ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i] 519 ers.keys[i], ers.keys[j] = ers.keys[j], ers.keys[i] 520} 521