1package sqlbuilder
2
3import (
4	"context"
5	"database/sql"
6	"errors"
7	"fmt"
8	"strings"
9
10	db "upper.io/db.v3"
11	"upper.io/db.v3/internal/immutable"
12	"upper.io/db.v3/internal/sqladapter/exql"
13)
14
// selectorQuery is the mutable state that a selector chain is replayed onto
// when it is built (selector.Base supplies an empty one, selector.Fn applies
// each step to it). Every clause fragment is paired with the placeholder
// arguments it contributes; arguments() concatenates those slices in a fixed
// order.
type selectorQuery struct {
	// FROM expression set by From, and its arguments.
	table     *exql.Columns
	tableArgs []interface{}

	// Set by Distinct; copied into exql.Statement.Distinct.
	distinct bool

	// WHERE conditions accumulated by Where/And, and their arguments.
	where     *exql.Where
	whereArgs []interface{}

	// GROUP BY clause set by GroupBy, and its arguments.
	groupBy     *exql.GroupBy
	groupByArgs []interface{}

	// ORDER BY clause set by OrderBy, and its arguments.
	orderBy     *exql.OrderBy
	orderByArgs []interface{}

	// Values set by Limit and Offset (negative inputs are clamped to 0 there).
	limit  exql.Limit
	offset exql.Offset

	// Select list built by pushColumns, and its arguments.
	columns     *exql.Columns
	columnsArgs []interface{}

	// Joins pushed by the *Join methods; ON/USING arguments also land in
	// joinsArgs.
	joins     []*exql.Join
	joinsArgs []interface{}

	// Optional function passed to exql.Statement.SetAmendment (set by Amend).
	amendFn func(string) string
}
41
42func (sq *selectorQuery) and(b *sqlBuilder, terms ...interface{}) error {
43	where, whereArgs := b.t.toWhereWithArguments(terms)
44
45	if sq.where == nil {
46		sq.where, sq.whereArgs = &exql.Where{}, []interface{}{}
47	}
48	sq.where.Append(&where)
49	sq.whereArgs = append(sq.whereArgs, whereArgs...)
50
51	return nil
52}
53
54func (sq *selectorQuery) arguments() []interface{} {
55	return joinArguments(
56		sq.columnsArgs,
57		sq.tableArgs,
58		sq.joinsArgs,
59		sq.whereArgs,
60		sq.groupByArgs,
61		sq.orderByArgs,
62	)
63}
64
65func (sq *selectorQuery) statement() *exql.Statement {
66	stmt := &exql.Statement{
67		Type:     exql.Select,
68		Table:    sq.table,
69		Columns:  sq.columns,
70		Distinct: sq.distinct,
71		Limit:    sq.limit,
72		Offset:   sq.offset,
73		Where:    sq.where,
74		OrderBy:  sq.orderBy,
75		GroupBy:  sq.groupBy,
76	}
77
78	if len(sq.joins) > 0 {
79		stmt.Joins = exql.JoinConditions(sq.joins...)
80	}
81
82	stmt.SetAmendment(sq.amendFn)
83
84	return stmt
85}
86
87func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error {
88	fragments, args, err := columnFragments(tables)
89	if err != nil {
90		return err
91	}
92
93	if sq.joins == nil {
94		sq.joins = []*exql.Join{}
95	}
96	sq.joins = append(sq.joins,
97		&exql.Join{
98			Type:  t,
99			Table: exql.JoinColumns(fragments...),
100		},
101	)
102
103	sq.joinsArgs = append(sq.joinsArgs, args...)
104
105	return nil
106}
107
108type selector struct {
109	builder *sqlBuilder
110
111	fn   func(*selectorQuery) error
112	prev *selector
113}
114
115var _ = immutable.Immutable(&selector{})
116
117func (sel *selector) SQLBuilder() *sqlBuilder {
118	if sel.prev == nil {
119		return sel.builder
120	}
121	return sel.prev.SQLBuilder()
122}
123
124func (sel *selector) String() string {
125	s, err := sel.Compile()
126	if err != nil {
127		panic(err.Error())
128	}
129	return prepareQueryForDisplay(s)
130}
131
132func (sel *selector) frame(fn func(*selectorQuery) error) *selector {
133	return &selector{prev: sel, fn: fn}
134}
135
136func (sel *selector) clone() Selector {
137	return sel.frame(func(*selectorQuery) error {
138		return nil
139	})
140}
141
142func (sel *selector) From(tables ...interface{}) Selector {
143	return sel.frame(
144		func(sq *selectorQuery) error {
145			fragments, args, err := columnFragments(tables)
146			if err != nil {
147				return err
148			}
149			sq.table = exql.JoinColumns(fragments...)
150			sq.tableArgs = args
151			return nil
152		},
153	)
154}
155
156func (sel *selector) setColumns(columns ...interface{}) Selector {
157	return sel.frame(func(sq *selectorQuery) error {
158		sq.columns = nil
159		return sq.pushColumns(columns...)
160	})
161}
162
163func (sel *selector) Columns(columns ...interface{}) Selector {
164	return sel.frame(func(sq *selectorQuery) error {
165		return sq.pushColumns(columns...)
166	})
167}
168
169func (sq *selectorQuery) pushColumns(columns ...interface{}) error {
170	f, args, err := columnFragments(columns)
171	if err != nil {
172		return err
173	}
174
175	c := exql.JoinColumns(f...)
176
177	if sq.columns != nil {
178		sq.columns.Append(c)
179	} else {
180		sq.columns = c
181	}
182
183	sq.columnsArgs = append(sq.columnsArgs, args...)
184	return nil
185}
186
187func (sel *selector) Distinct(exps ...interface{}) Selector {
188	return sel.frame(func(sq *selectorQuery) error {
189		sq.distinct = true
190		return sq.pushColumns(exps...)
191	})
192}
193
194func (sel *selector) Where(terms ...interface{}) Selector {
195	return sel.frame(func(sq *selectorQuery) error {
196		if len(terms) == 1 && terms[0] == nil {
197			sq.where, sq.whereArgs = &exql.Where{}, []interface{}{}
198			return nil
199		}
200		return sq.and(sel.SQLBuilder(), terms...)
201	})
202}
203
204func (sel *selector) And(terms ...interface{}) Selector {
205	return sel.frame(func(sq *selectorQuery) error {
206		return sq.and(sel.SQLBuilder(), terms...)
207	})
208}
209
210func (sel *selector) Amend(fn func(string) string) Selector {
211	return sel.frame(func(sq *selectorQuery) error {
212		sq.amendFn = fn
213		return nil
214	})
215}
216
217func (sel *selector) Arguments() []interface{} {
218	sq, err := sel.build()
219	if err != nil {
220		return nil
221	}
222	return sq.arguments()
223}
224
225func (sel *selector) GroupBy(columns ...interface{}) Selector {
226	return sel.frame(func(sq *selectorQuery) error {
227		fragments, args, err := columnFragments(columns)
228		if err != nil {
229			return err
230		}
231
232		if fragments != nil {
233			sq.groupBy = exql.GroupByColumns(fragments...)
234		}
235		sq.groupByArgs = args
236
237		return nil
238	})
239}
240
241func (sel *selector) OrderBy(columns ...interface{}) Selector {
242	return sel.frame(func(sq *selectorQuery) error {
243
244		if len(columns) == 1 && columns[0] == nil {
245			sq.orderBy = nil
246			sq.orderByArgs = nil
247			return nil
248		}
249
250		var sortColumns exql.SortColumns
251
252		for i := range columns {
253			var sort *exql.SortColumn
254
255			switch value := columns[i].(type) {
256			case db.RawValue:
257				query, args := Preprocess(value.Raw(), value.Arguments())
258				sort = &exql.SortColumn{
259					Column: exql.RawValue(query),
260				}
261				sq.orderByArgs = append(sq.orderByArgs, args...)
262			case db.Function:
263				fnName, fnArgs := value.Name(), value.Arguments()
264				if len(fnArgs) == 0 {
265					fnName = fnName + "()"
266				} else {
267					fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
268				}
269				fnName, fnArgs = Preprocess(fnName, fnArgs)
270				sort = &exql.SortColumn{
271					Column: exql.RawValue(fnName),
272				}
273				sq.orderByArgs = append(sq.orderByArgs, fnArgs...)
274			case string:
275				if strings.HasPrefix(value, "-") {
276					sort = &exql.SortColumn{
277						Column: exql.ColumnWithName(value[1:]),
278						Order:  exql.Descendent,
279					}
280				} else {
281					chunks := strings.SplitN(value, " ", 2)
282
283					order := exql.Ascendent
284					if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" {
285						order = exql.Descendent
286					}
287
288					sort = &exql.SortColumn{
289						Column: exql.ColumnWithName(chunks[0]),
290						Order:  order,
291					}
292				}
293			default:
294				return fmt.Errorf("Can't sort by type %T", value)
295			}
296			sortColumns.Columns = append(sortColumns.Columns, sort)
297		}
298
299		sq.orderBy = &exql.OrderBy{
300			SortColumns: &sortColumns,
301		}
302		return nil
303	})
304}
305
306func (sel *selector) Using(columns ...interface{}) Selector {
307	return sel.frame(func(sq *selectorQuery) error {
308
309		joins := len(sq.joins)
310		if joins == 0 {
311			return errors.New(`cannot use Using() without a preceding Join() expression`)
312		}
313
314		lastJoin := sq.joins[joins-1]
315		if lastJoin.On != nil {
316			return errors.New(`cannot use Using() and On() with the same Join() expression`)
317		}
318
319		fragments, args, err := columnFragments(columns)
320		if err != nil {
321			return err
322		}
323
324		sq.joinsArgs = append(sq.joinsArgs, args...)
325		lastJoin.Using = exql.UsingColumns(fragments...)
326
327		return nil
328	})
329}
330
331func (sel *selector) FullJoin(tables ...interface{}) Selector {
332	return sel.frame(func(sq *selectorQuery) error {
333		return sq.pushJoin("FULL", tables)
334	})
335}
336
337func (sel *selector) CrossJoin(tables ...interface{}) Selector {
338	return sel.frame(func(sq *selectorQuery) error {
339		return sq.pushJoin("CROSS", tables)
340	})
341}
342
343func (sel *selector) RightJoin(tables ...interface{}) Selector {
344	return sel.frame(func(sq *selectorQuery) error {
345		return sq.pushJoin("RIGHT", tables)
346	})
347}
348
349func (sel *selector) LeftJoin(tables ...interface{}) Selector {
350	return sel.frame(func(sq *selectorQuery) error {
351		return sq.pushJoin("LEFT", tables)
352	})
353}
354
355func (sel *selector) Join(tables ...interface{}) Selector {
356	return sel.frame(func(sq *selectorQuery) error {
357		return sq.pushJoin("", tables)
358	})
359}
360
361func (sel *selector) On(terms ...interface{}) Selector {
362	return sel.frame(func(sq *selectorQuery) error {
363		joins := len(sq.joins)
364
365		if joins == 0 {
366			return errors.New(`cannot use On() without a preceding Join() expression`)
367		}
368
369		lastJoin := sq.joins[joins-1]
370		if lastJoin.On != nil {
371			return errors.New(`cannot use Using() and On() with the same Join() expression`)
372		}
373
374		w, a := sel.SQLBuilder().t.toWhereWithArguments(terms)
375		o := exql.On(w)
376
377		lastJoin.On = &o
378
379		sq.joinsArgs = append(sq.joinsArgs, a...)
380
381		return nil
382	})
383}
384
385func (sel *selector) Limit(n int) Selector {
386	return sel.frame(func(sq *selectorQuery) error {
387		if n < 0 {
388			n = 0
389		}
390		sq.limit = exql.Limit(n)
391		return nil
392	})
393}
394
395func (sel *selector) Offset(n int) Selector {
396	return sel.frame(func(sq *selectorQuery) error {
397		if n < 0 {
398			n = 0
399		}
400		sq.offset = exql.Offset(n)
401		return nil
402	})
403}
404
405func (sel *selector) template() *exql.Template {
406	return sel.SQLBuilder().t.Template
407}
408
409func (sel *selector) As(alias string) Selector {
410	return sel.frame(func(sq *selectorQuery) error {
411		if sq.table == nil {
412			return errors.New("Cannot use As() without a preceding From() expression")
413		}
414		last := len(sq.table.Columns) - 1
415		if raw, ok := sq.table.Columns[last].(*exql.Raw); ok {
416			compiled, err := exql.ColumnWithName(alias).Compile(sel.template())
417			if err != nil {
418				return err
419			}
420			sq.table.Columns[last] = exql.RawValue(raw.Value + " AS " + compiled)
421		}
422		return nil
423	})
424}
425
426func (sel *selector) statement() *exql.Statement {
427	sq, _ := sel.build()
428	return sq.statement()
429}
430
431func (sel *selector) QueryRow() (*sql.Row, error) {
432	return sel.QueryRowContext(sel.SQLBuilder().sess.Context())
433}
434
435func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) {
436	sq, err := sel.build()
437	if err != nil {
438		return nil, err
439	}
440
441	return sel.SQLBuilder().sess.StatementQueryRow(ctx, sq.statement(), sq.arguments()...)
442}
443
444func (sel *selector) Prepare() (*sql.Stmt, error) {
445	return sel.PrepareContext(sel.SQLBuilder().sess.Context())
446}
447
448func (sel *selector) PrepareContext(ctx context.Context) (*sql.Stmt, error) {
449	sq, err := sel.build()
450	if err != nil {
451		return nil, err
452	}
453	return sel.SQLBuilder().sess.StatementPrepare(ctx, sq.statement())
454}
455
456func (sel *selector) Query() (*sql.Rows, error) {
457	return sel.QueryContext(sel.SQLBuilder().sess.Context())
458}
459
460func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) {
461	sq, err := sel.build()
462	if err != nil {
463		return nil, err
464	}
465	return sel.SQLBuilder().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...)
466}
467
468func (sel *selector) Iterator() Iterator {
469	return sel.IteratorContext(sel.SQLBuilder().sess.Context())
470}
471
472func (sel *selector) IteratorContext(ctx context.Context) Iterator {
473	sess := sel.SQLBuilder().sess
474	sq, err := sel.build()
475	if err != nil {
476		return &iterator{sess, nil, err}
477	}
478
479	rows, err := sess.StatementQuery(ctx, sq.statement(), sq.arguments()...)
480	return &iterator{sess, rows, err}
481}
482
483func (sel *selector) Paginate(pageSize uint) Paginator {
484	return newPaginator(sel.clone(), pageSize)
485}
486
487func (sel *selector) All(destSlice interface{}) error {
488	return sel.Iterator().All(destSlice)
489}
490
491func (sel *selector) One(dest interface{}) error {
492	return sel.Iterator().One(dest)
493}
494
495func (sel *selector) build() (*selectorQuery, error) {
496	sq, err := immutable.FastForward(sel)
497	if err != nil {
498		return nil, err
499	}
500	return sq.(*selectorQuery), nil
501}
502
503func (sel *selector) Compile() (string, error) {
504	return sel.statement().Compile(sel.template())
505}
506
507func (sel *selector) Prev() immutable.Immutable {
508	if sel == nil {
509		return nil
510	}
511	return sel.prev
512}
513
514func (sel *selector) Fn(in interface{}) error {
515	if sel.fn == nil {
516		return nil
517	}
518	return sel.fn(in.(*selectorQuery))
519}
520
521func (sel *selector) Base() interface{} {
522	return &selectorQuery{}
523}
524