1package squirrel
2
3import (
4	"bytes"
5	"database/sql"
6	"fmt"
7	"sort"
8	"strings"
9
10	"github.com/lann/builder"
11)
12
13type updateData struct {
14	PlaceholderFormat PlaceholderFormat
15	RunWith           BaseRunner
16	Prefixes          []Sqlizer
17	Table             string
18	SetClauses        []setClause
19	WhereParts        []Sqlizer
20	OrderBys          []string
21	Limit             string
22	Offset            string
23	Suffixes          []Sqlizer
24}
25
26type setClause struct {
27	column string
28	value  interface{}
29}
30
// Exec renders the statement and executes it through d.RunWith.
// It returns RunnerNotSet when no runner has been configured.
func (d *updateData) Exec() (sql.Result, error) {
	runner := d.RunWith
	if runner != nil {
		return ExecWith(runner, d)
	}
	return nil, RunnerNotSet
}
37
// Query renders the statement and runs it through d.RunWith as a query,
// returning the resulting rows. It returns RunnerNotSet when no runner has
// been configured.
func (d *updateData) Query() (*sql.Rows, error) {
	runner := d.RunWith
	if runner != nil {
		return QueryWith(runner, d)
	}
	return nil, RunnerNotSet
}
44
45func (d *updateData) QueryRow() RowScanner {
46	if d.RunWith == nil {
47		return &Row{err: RunnerNotSet}
48	}
49	queryRower, ok := d.RunWith.(QueryRower)
50	if !ok {
51		return &Row{err: RunnerNotQueryRunner}
52	}
53	return QueryRowWith(queryRower, d)
54}
55
// ToSql renders the UPDATE statement held in d into a SQL string and its
// bound arguments. Clauses are emitted in this order:
//
//	[prefixes] UPDATE table SET col = val[, ...] [WHERE ...] [ORDER BY ...]
//	[LIMIT n] [OFFSET n] [suffixes]
//
// A Set value that implements Sqlizer is inlined as SQL, and its args are
// appended. Subqueries built with SelectBuilder (by value or by pointer) are
// wrapped in parentheses. Any other value is bound as a "?" placeholder.
// The finished string is passed through d.PlaceholderFormat.
//
// An error is returned when no table or no Set clause was given, or when a
// nested Sqlizer or the placeholder rewrite fails. On error, sqlStr is "" and
// args is nil.
func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) {
	if len(d.Table) == 0 {
		return "", nil, fmt.Errorf("update statements must specify a table")
	}
	if len(d.SetClauses) == 0 {
		return "", nil, fmt.Errorf("update statements must have at least one Set clause")
	}

	// Named buf (not sql) so the database/sql package is not shadowed.
	buf := &bytes.Buffer{}

	if len(d.Prefixes) > 0 {
		args, err = appendToSql(d.Prefixes, buf, " ", args)
		if err != nil {
			return "", nil, err
		}
		buf.WriteString(" ")
	}

	buf.WriteString("UPDATE ")
	buf.WriteString(d.Table)

	buf.WriteString(" SET ")
	setSqls := make([]string, len(d.SetClauses))
	for i, clause := range d.SetClauses {
		var valSql string
		if vs, ok := clause.value.(Sqlizer); ok {
			vsql, vargs, verr := vs.ToSql()
			if verr != nil {
				return "", nil, verr
			}
			valSql = vsql
			// A subquery is only a valid SET value when parenthesized.
			// Accept the builder by pointer too, since its value-receiver
			// ToSql makes *SelectBuilder a Sqlizer as well.
			switch vs.(type) {
			case SelectBuilder, *SelectBuilder:
				valSql = "(" + vsql + ")"
			}
			args = append(args, vargs...)
		} else {
			valSql = "?"
			args = append(args, clause.value)
		}
		setSqls[i] = clause.column + " = " + valSql
	}
	buf.WriteString(strings.Join(setSqls, ", "))

	if len(d.WhereParts) > 0 {
		buf.WriteString(" WHERE ")
		args, err = appendToSql(d.WhereParts, buf, " AND ", args)
		if err != nil {
			return "", nil, err
		}
	}

	if len(d.OrderBys) > 0 {
		buf.WriteString(" ORDER BY ")
		buf.WriteString(strings.Join(d.OrderBys, ", "))
	}

	if len(d.Limit) > 0 {
		buf.WriteString(" LIMIT ")
		buf.WriteString(d.Limit)
	}

	if len(d.Offset) > 0 {
		buf.WriteString(" OFFSET ")
		buf.WriteString(d.Offset)
	}

	if len(d.Suffixes) > 0 {
		buf.WriteString(" ")
		args, err = appendToSql(d.Suffixes, buf, " ", args)
		if err != nil {
			return "", nil, err
		}
	}

	sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(buf.String())
	if err != nil {
		return "", nil, err
	}
	return sqlStr, args, nil
}
137
138// Builder
139
140// UpdateBuilder builds SQL UPDATE statements.
141type UpdateBuilder builder.Builder
142
143func init() {
144	builder.Register(UpdateBuilder{}, updateData{})
145}
146
147// Format methods
148
149// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
150// query.
151func (b UpdateBuilder) PlaceholderFormat(f PlaceholderFormat) UpdateBuilder {
152	return builder.Set(b, "PlaceholderFormat", f).(UpdateBuilder)
153}
154
155// Runner methods
156
157// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
158func (b UpdateBuilder) RunWith(runner BaseRunner) UpdateBuilder {
159	return setRunWith(b, runner).(UpdateBuilder)
160}
161
// Exec renders the statement and executes it with the runner configured via
// RunWith. It returns RunnerNotSet when no runner was configured.
func (b UpdateBuilder) Exec() (sql.Result, error) {
	d := builder.GetStruct(b).(updateData)
	return d.Exec()
}
167
// Query renders the statement and runs it as a query with the runner
// configured via RunWith, returning the resulting rows. This is useful with
// databases that can return rows from an UPDATE, such as via a RETURNING suffix.
func (b UpdateBuilder) Query() (*sql.Rows, error) {
	d := builder.GetStruct(b).(updateData)
	return d.Query()
}
172
173func (b UpdateBuilder) QueryRow() RowScanner {
174	data := builder.GetStruct(b).(updateData)
175	return data.QueryRow()
176}
177
178func (b UpdateBuilder) Scan(dest ...interface{}) error {
179	return b.QueryRow().Scan(dest...)
180}
181
182// SQL methods
183
184// ToSql builds the query into a SQL string and bound args.
185func (b UpdateBuilder) ToSql() (string, []interface{}, error) {
186	data := builder.GetStruct(b).(updateData)
187	return data.ToSql()
188}
189
190// Prefix adds an expression to the beginning of the query
191func (b UpdateBuilder) Prefix(sql string, args ...interface{}) UpdateBuilder {
192	return b.PrefixExpr(Expr(sql, args...))
193}
194
195// PrefixExpr adds an expression to the very beginning of the query
196func (b UpdateBuilder) PrefixExpr(expr Sqlizer) UpdateBuilder {
197	return builder.Append(b, "Prefixes", expr).(UpdateBuilder)
198}
199
200// Table sets the table to be updated.
201func (b UpdateBuilder) Table(table string) UpdateBuilder {
202	return builder.Set(b, "Table", table).(UpdateBuilder)
203}
204
205// Set adds SET clauses to the query.
206func (b UpdateBuilder) Set(column string, value interface{}) UpdateBuilder {
207	return builder.Append(b, "SetClauses", setClause{column: column, value: value}).(UpdateBuilder)
208}
209
// SetMap is a convenience method which calls Set for each key/value pair in
// clauses. Keys are applied in ascending lexical order so that the generated
// SQL and argument order are deterministic despite Go's randomized map
// iteration.
func (b UpdateBuilder) SetMap(clauses map[string]interface{}) UpdateBuilder {
	keys := make([]string, 0, len(clauses))
	for key := range clauses {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	for _, key := range keys {
		b = b.Set(key, clauses[key])
	}
	return b
}
225
226// Where adds WHERE expressions to the query.
227//
228// See SelectBuilder.Where for more information.
229func (b UpdateBuilder) Where(pred interface{}, args ...interface{}) UpdateBuilder {
230	return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(UpdateBuilder)
231}
232
233// OrderBy adds ORDER BY expressions to the query.
234func (b UpdateBuilder) OrderBy(orderBys ...string) UpdateBuilder {
235	return builder.Extend(b, "OrderBys", orderBys).(UpdateBuilder)
236}
237
// Limit sets a LIMIT clause on the query, replacing any earlier limit.
func (b UpdateBuilder) Limit(limit uint64) UpdateBuilder {
	rendered := fmt.Sprintf("%d", limit)
	updated := builder.Set(b, "Limit", rendered)
	return updated.(UpdateBuilder)
}
242
// Offset sets an OFFSET clause on the query, replacing any earlier offset.
func (b UpdateBuilder) Offset(offset uint64) UpdateBuilder {
	rendered := fmt.Sprintf("%d", offset)
	updated := builder.Set(b, "Offset", rendered)
	return updated.(UpdateBuilder)
}
247
248// Suffix adds an expression to the end of the query
249func (b UpdateBuilder) Suffix(sql string, args ...interface{}) UpdateBuilder {
250	return b.SuffixExpr(Expr(sql, args...))
251}
252
253// SuffixExpr adds an expression to the end of the query
254func (b UpdateBuilder) SuffixExpr(expr Sqlizer) UpdateBuilder {
255	return builder.Append(b, "Suffixes", expr).(UpdateBuilder)
256}
257