1package squirrel
2
3import (
4	"bytes"
5	"database/sql"
6	"fmt"
7	"sort"
8	"strings"
9
10	"github.com/lann/builder"
11)
12
// updateData is the plain state behind an UpdateBuilder. The UpdateBuilder
// methods address these fields by their string names through builder.Set,
// builder.Append and builder.Extend, so renaming a field also requires
// updating those string keys.
type updateData struct {
	PlaceholderFormat PlaceholderFormat // rewrites "?" placeholders in the SQL produced by ToSql
	RunWith           BaseRunner        // runner used by Exec, Query and QueryRow; nil means not set
	Prefixes          exprs             // expressions emitted before UPDATE, separated by spaces
	Table             string            // target table; ToSql fails when empty
	SetClauses        []setClause       // column assignments; ToSql fails when empty
	WhereParts        []Sqlizer         // WHERE predicates, joined with " AND "
	OrderBys          []string          // ORDER BY expressions, joined with ", "
	Limit             string            // LIMIT value as decimal text; empty means no LIMIT clause
	Offset            string            // OFFSET value as decimal text; empty means no OFFSET clause
	Suffixes          exprs             // expressions emitted at the very end, separated by spaces
}
25
// setClause is one "column = value" assignment in an UPDATE's SET list.
// ToSql inlines value as raw SQL when it is an expr (as returned by Expr);
// any other value is bound as a single placeholder argument.
type setClause struct {
	column string      // column name, written into the SQL verbatim (not quoted)
	value  interface{} // bound argument, or an expr to inline
}
30
31func (d *updateData) Exec() (sql.Result, error) {
32	if d.RunWith == nil {
33		return nil, RunnerNotSet
34	}
35	return ExecWith(d.RunWith, d)
36}
37
38func (d *updateData) Query() (*sql.Rows, error) {
39	if d.RunWith == nil {
40		return nil, RunnerNotSet
41	}
42	return QueryWith(d.RunWith, d)
43}
44
45func (d *updateData) QueryRow() RowScanner {
46	if d.RunWith == nil {
47		return &Row{err: RunnerNotSet}
48	}
49	queryRower, ok := d.RunWith.(QueryRower)
50	if !ok {
51		return &Row{err: RunnerNotQueryRunner}
52	}
53	return QueryRowWith(queryRower, d)
54}
55
56func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) {
57	if len(d.Table) == 0 {
58		err = fmt.Errorf("update statements must specify a table")
59		return
60	}
61	if len(d.SetClauses) == 0 {
62		err = fmt.Errorf("update statements must have at least one Set clause")
63		return
64	}
65
66	sql := &bytes.Buffer{}
67
68	if len(d.Prefixes) > 0 {
69		args, _ = d.Prefixes.AppendToSql(sql, " ", args)
70		sql.WriteString(" ")
71	}
72
73	sql.WriteString("UPDATE ")
74	sql.WriteString(d.Table)
75
76	sql.WriteString(" SET ")
77	setSqls := make([]string, len(d.SetClauses))
78	for i, setClause := range d.SetClauses {
79		var valSql string
80		e, isExpr := setClause.value.(expr)
81		if isExpr {
82			valSql = e.sql
83			args = append(args, e.args...)
84		} else {
85			valSql = "?"
86			args = append(args, setClause.value)
87		}
88		setSqls[i] = fmt.Sprintf("%s = %s", setClause.column, valSql)
89	}
90	sql.WriteString(strings.Join(setSqls, ", "))
91
92	if len(d.WhereParts) > 0 {
93		sql.WriteString(" WHERE ")
94		args, err = appendToSql(d.WhereParts, sql, " AND ", args)
95		if err != nil {
96			return
97		}
98	}
99
100	if len(d.OrderBys) > 0 {
101		sql.WriteString(" ORDER BY ")
102		sql.WriteString(strings.Join(d.OrderBys, ", "))
103	}
104
105	if len(d.Limit) > 0 {
106		sql.WriteString(" LIMIT ")
107		sql.WriteString(d.Limit)
108	}
109
110	if len(d.Offset) > 0 {
111		sql.WriteString(" OFFSET ")
112		sql.WriteString(d.Offset)
113	}
114
115	if len(d.Suffixes) > 0 {
116		sql.WriteString(" ")
117		args, _ = d.Suffixes.AppendToSql(sql, " ", args)
118	}
119
120	sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
121	return
122}
123
124// Builder
125
// UpdateBuilder builds SQL UPDATE statements.
//
// Every method takes a value receiver and returns the updated UpdateBuilder,
// so calls are chained, e.g. b.Table("t").Set("c", 1).Where("id = ?", 2).
// Its state is an updateData, registered with the builder package in this
// file's init function.
type UpdateBuilder builder.Builder
128
// init registers UpdateBuilder with the builder package, using updateData as
// its backing struct. The UpdateBuilder methods depend on this when they
// assert builder.GetStruct(b) to updateData.
func init() {
	builder.Register(UpdateBuilder{}, updateData{})
}
132
133// Format methods
134
135// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
136// query.
137func (b UpdateBuilder) PlaceholderFormat(f PlaceholderFormat) UpdateBuilder {
138	return builder.Set(b, "PlaceholderFormat", f).(UpdateBuilder)
139}
140
141// Runner methods
142
143// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
144func (b UpdateBuilder) RunWith(runner BaseRunner) UpdateBuilder {
145	return setRunWith(b, runner).(UpdateBuilder)
146}
147
148// Exec builds and Execs the query with the Runner set by RunWith.
149func (b UpdateBuilder) Exec() (sql.Result, error) {
150	data := builder.GetStruct(b).(updateData)
151	return data.Exec()
152}
153
154func (b UpdateBuilder) Query() (*sql.Rows, error) {
155	data := builder.GetStruct(b).(updateData)
156	return data.Query()
157}
158
159func (b UpdateBuilder) QueryRow() RowScanner {
160	data := builder.GetStruct(b).(updateData)
161	return data.QueryRow()
162}
163
164func (b UpdateBuilder) Scan(dest ...interface{}) error {
165	return b.QueryRow().Scan(dest...)
166}
167
168// SQL methods
169
170// ToSql builds the query into a SQL string and bound args.
171func (b UpdateBuilder) ToSql() (string, []interface{}, error) {
172	data := builder.GetStruct(b).(updateData)
173	return data.ToSql()
174}
175
176// Prefix adds an expression to the beginning of the query
177func (b UpdateBuilder) Prefix(sql string, args ...interface{}) UpdateBuilder {
178	return builder.Append(b, "Prefixes", Expr(sql, args...)).(UpdateBuilder)
179}
180
181// Table sets the table to be updated.
182func (b UpdateBuilder) Table(table string) UpdateBuilder {
183	return builder.Set(b, "Table", table).(UpdateBuilder)
184}
185
186// Set adds SET clauses to the query.
187func (b UpdateBuilder) Set(column string, value interface{}) UpdateBuilder {
188	return builder.Append(b, "SetClauses", setClause{column: column, value: value}).(UpdateBuilder)
189}
190
191// SetMap is a convenience method which calls .Set for each key/value pair in clauses.
192func (b UpdateBuilder) SetMap(clauses map[string]interface{}) UpdateBuilder {
193	keys := make([]string, len(clauses))
194	i := 0
195	for key := range clauses {
196		keys[i] = key
197		i++
198	}
199	sort.Strings(keys)
200	for _, key := range keys {
201		val, _ := clauses[key]
202		b = b.Set(key, val)
203	}
204	return b
205}
206
207// Where adds WHERE expressions to the query.
208//
209// See SelectBuilder.Where for more information.
210func (b UpdateBuilder) Where(pred interface{}, args ...interface{}) UpdateBuilder {
211	return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(UpdateBuilder)
212}
213
214// OrderBy adds ORDER BY expressions to the query.
215func (b UpdateBuilder) OrderBy(orderBys ...string) UpdateBuilder {
216	return builder.Extend(b, "OrderBys", orderBys).(UpdateBuilder)
217}
218
219// Limit sets a LIMIT clause on the query.
220func (b UpdateBuilder) Limit(limit uint64) UpdateBuilder {
221	return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(UpdateBuilder)
222}
223
224// Offset sets a OFFSET clause on the query.
225func (b UpdateBuilder) Offset(offset uint64) UpdateBuilder {
226	return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(UpdateBuilder)
227}
228
229// Suffix adds an expression to the end of the query
230func (b UpdateBuilder) Suffix(sql string, args ...interface{}) UpdateBuilder {
231	return builder.Append(b, "Suffixes", Expr(sql, args...)).(UpdateBuilder)
232}
233