1// Copyright 2018 Huan Du. All rights reserved.
2// Licensed under the MIT license that can be found in the LICENSE file.
3
4package sqlbuilder
5
6import (
7	"bytes"
8	"fmt"
9	"strconv"
10	"strings"
11)
12
// Injection markers identify the positions in an UPDATE statement where
// fragments registered via (*UpdateBuilder).SQL are written by
// BuildWithFlavor. The iota order matches the order in which the clauses
// are emitted.
const (
	// updateMarkerInit: before the "UPDATE" keyword. It is the zero value of
	// the builder's marker field, so SQL() on a fresh builder lands here.
	updateMarkerInit injectionMarker = iota
	// updateMarkerAfterUpdate: after the table name.
	updateMarkerAfterUpdate
	// updateMarkerAfterSet: after the SET assignments.
	updateMarkerAfterSet
	// updateMarkerAfterWhere: after the WHERE expressions (written only when
	// there is at least one expression).
	updateMarkerAfterWhere
	// updateMarkerAfterOrderBy: after the ORDER BY columns and direction
	// (written only when there is at least one column).
	updateMarkerAfterOrderBy
	// updateMarkerAfterLimit: after LIMIT (written only when limit >= 0).
	updateMarkerAfterLimit
)
21
22// NewUpdateBuilder creates a new UPDATE builder.
23func NewUpdateBuilder() *UpdateBuilder {
24	return DefaultFlavor.NewUpdateBuilder()
25}
26
27func newUpdateBuilder() *UpdateBuilder {
28	args := &Args{}
29	return &UpdateBuilder{
30		Cond: Cond{
31			Args: args,
32		},
33		limit:     -1,
34		args:      args,
35		injection: newInjection(),
36	}
37}
38
// UpdateBuilder is a builder to build UPDATE.
//
// The clause methods (Update, Set, SetMore, Where, OrderBy, Asc, Desc, Limit)
// only record state and move marker; the SQL text is assembled in
// BuildWithFlavor.
type UpdateBuilder struct {
	// Cond is wired to the same *Args as the args field (see
	// newUpdateBuilder); presumably this lets Cond's helpers and
	// Assign/Add/Sub/Mul/Div share one argument list.
	Cond

	// table is the table name, already passed through Escape by Update.
	table       string
	// assignments are the SET expressions, joined with ", " when built.
	assignments []string
	// whereExprs accumulate across Where calls and are joined with " AND ".
	whereExprs  []string
	// orderByCols are the ORDER BY columns, joined with ", ".
	orderByCols []string
	// order is "ASC", "DESC" or "" (no direction); it is only written when
	// orderByCols is non-empty.
	order       string
	// limit is written as LIMIT when >= 0; newUpdateBuilder starts it at -1.
	limit       int

	// args receives values bound by Assign and friends and compiles the final
	// SQL in BuildWithFlavor; its Flavor is the one Build uses.
	args *Args

	// injection holds fragments added by SQL, each tagged with the marker
	// that was current when SQL was called.
	injection *injection
	// marker is the position the next SQL call attaches to; every clause
	// method updates it.
	marker    injectionMarker
}
55
56var _ Builder = new(UpdateBuilder)
57
58// Update sets table name in UPDATE.
59func Update(table string) *UpdateBuilder {
60	return DefaultFlavor.NewUpdateBuilder().Update(table)
61}
62
63// Update sets table name in UPDATE.
64func (ub *UpdateBuilder) Update(table string) *UpdateBuilder {
65	ub.table = Escape(table)
66	ub.marker = updateMarkerAfterUpdate
67	return ub
68}
69
70// Set sets the assignements in SET.
71func (ub *UpdateBuilder) Set(assignment ...string) *UpdateBuilder {
72	ub.assignments = assignment
73	ub.marker = updateMarkerAfterSet
74	return ub
75}
76
77// SetMore appends the assignements in SET.
78func (ub *UpdateBuilder) SetMore(assignment ...string) *UpdateBuilder {
79	ub.assignments = append(ub.assignments, assignment...)
80	ub.marker = updateMarkerAfterSet
81	return ub
82}
83
84// Where sets expressions of WHERE in UPDATE.
85func (ub *UpdateBuilder) Where(andExpr ...string) *UpdateBuilder {
86	ub.whereExprs = append(ub.whereExprs, andExpr...)
87	ub.marker = updateMarkerAfterWhere
88	return ub
89}
90
91// Assign represents SET "field = value" in UPDATE.
92func (ub *UpdateBuilder) Assign(field string, value interface{}) string {
93	return fmt.Sprintf("%s = %s", Escape(field), ub.args.Add(value))
94}
95
96// Incr represents SET "field = field + 1" in UPDATE.
97func (ub *UpdateBuilder) Incr(field string) string {
98	f := Escape(field)
99	return fmt.Sprintf("%s = %s + 1", f, f)
100}
101
102// Decr represents SET "field = field - 1" in UPDATE.
103func (ub *UpdateBuilder) Decr(field string) string {
104	f := Escape(field)
105	return fmt.Sprintf("%s = %s - 1", f, f)
106}
107
108// Add represents SET "field = field + value" in UPDATE.
109func (ub *UpdateBuilder) Add(field string, value interface{}) string {
110	f := Escape(field)
111	return fmt.Sprintf("%s = %s + %s", f, f, ub.args.Add(value))
112}
113
114// Sub represents SET "field = field - value" in UPDATE.
115func (ub *UpdateBuilder) Sub(field string, value interface{}) string {
116	f := Escape(field)
117	return fmt.Sprintf("%s = %s - %s", f, f, ub.args.Add(value))
118}
119
120// Mul represents SET "field = field * value" in UPDATE.
121func (ub *UpdateBuilder) Mul(field string, value interface{}) string {
122	f := Escape(field)
123	return fmt.Sprintf("%s = %s * %s", f, f, ub.args.Add(value))
124}
125
126// Div represents SET "field = field / value" in UPDATE.
127func (ub *UpdateBuilder) Div(field string, value interface{}) string {
128	f := Escape(field)
129	return fmt.Sprintf("%s = %s / %s", f, f, ub.args.Add(value))
130}
131
132// OrderBy sets columns of ORDER BY in UPDATE.
133func (ub *UpdateBuilder) OrderBy(col ...string) *UpdateBuilder {
134	ub.orderByCols = col
135	ub.marker = updateMarkerAfterOrderBy
136	return ub
137}
138
139// Asc sets order of ORDER BY to ASC.
140func (ub *UpdateBuilder) Asc() *UpdateBuilder {
141	ub.order = "ASC"
142	ub.marker = updateMarkerAfterOrderBy
143	return ub
144}
145
146// Desc sets order of ORDER BY to DESC.
147func (ub *UpdateBuilder) Desc() *UpdateBuilder {
148	ub.order = "DESC"
149	ub.marker = updateMarkerAfterOrderBy
150	return ub
151}
152
153// Limit sets the LIMIT in UPDATE.
154func (ub *UpdateBuilder) Limit(limit int) *UpdateBuilder {
155	ub.limit = limit
156	ub.marker = updateMarkerAfterLimit
157	return ub
158}
159
160// String returns the compiled UPDATE string.
161func (ub *UpdateBuilder) String() string {
162	s, _ := ub.Build()
163	return s
164}
165
166// Build returns compiled UPDATE string and args.
167// They can be used in `DB#Query` of package `database/sql` directly.
168func (ub *UpdateBuilder) Build() (sql string, args []interface{}) {
169	return ub.BuildWithFlavor(ub.args.Flavor)
170}
171
172// BuildWithFlavor returns compiled UPDATE string and args with flavor and initial args.
173// They can be used in `DB#Query` of package `database/sql` directly.
174func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
175	buf := &bytes.Buffer{}
176	ub.injection.WriteTo(buf, updateMarkerInit)
177	buf.WriteString("UPDATE ")
178	buf.WriteString(ub.table)
179	ub.injection.WriteTo(buf, updateMarkerAfterUpdate)
180
181	buf.WriteString(" SET ")
182	buf.WriteString(strings.Join(ub.assignments, ", "))
183	ub.injection.WriteTo(buf, updateMarkerAfterSet)
184
185	if len(ub.whereExprs) > 0 {
186		buf.WriteString(" WHERE ")
187		buf.WriteString(strings.Join(ub.whereExprs, " AND "))
188		ub.injection.WriteTo(buf, updateMarkerAfterWhere)
189	}
190
191	if len(ub.orderByCols) > 0 {
192		buf.WriteString(" ORDER BY ")
193		buf.WriteString(strings.Join(ub.orderByCols, ", "))
194
195		if ub.order != "" {
196			buf.WriteRune(' ')
197			buf.WriteString(ub.order)
198		}
199
200		ub.injection.WriteTo(buf, updateMarkerAfterOrderBy)
201	}
202
203	if ub.limit >= 0 {
204		buf.WriteString(" LIMIT ")
205		buf.WriteString(strconv.Itoa(ub.limit))
206
207		ub.injection.WriteTo(buf, updateMarkerAfterLimit)
208	}
209
210	return ub.args.CompileWithFlavor(buf.String(), flavor, initialArg...)
211}
212
213// SetFlavor sets the flavor of compiled sql.
214func (ub *UpdateBuilder) SetFlavor(flavor Flavor) (old Flavor) {
215	old = ub.args.Flavor
216	ub.args.Flavor = flavor
217	return
218}
219
220// SQL adds an arbitrary sql to current position.
221func (ub *UpdateBuilder) SQL(sql string) *UpdateBuilder {
222	ub.injection.SQL(ub.marker, sql)
223	return ub
224}
225